美文网首页
MXNET笔记(二)准备数据

MXNET笔记(二)准备数据

作者: 学而时习之_不亦说乎 | 来源:发表于2017-02-10 12:37 被阅读2951次

    MXNET并不直接读入图像,而是读入其自定义的一种格式。为了生成这种格式,需要利用/mxnet/tools/im2rec.py工具来对数据库图像进行处理而生成。我现在手头没有现成的数据库可以使用,而现在一般的数据库又太大了,所以我把Rec格式的MNIST数据库还原成图像文件。

    #利用mxnet提供的代码下载MNIST数据库
    import numpy as np
    import os
    import urllib
    import gzip
    import struct
    def download_data(url, force_download=True): 
        fname = url.split("/")[-1]
        if force_download or not os.path.exists(fname):
            urllib.urlretrieve(url, fname)
        return fname
    
    def read_data(label_url, image_url):
        with gzip.open(download_data(label_url)) as flbl:
            magic, num = struct.unpack(">II", flbl.read(8))
            label = np.fromstring(flbl.read(), dtype=np.int8)
        with gzip.open(download_data(image_url), 'rb') as fimg:
            magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
            image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
        return (label, image)
    
    path='http://yann.lecun.com/exdb/mnist/'
    (train_lbl, train_img) = read_data(
        path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')
    (val_lbl, val_img) = read_data(
        path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')
    #---------------------------------------------------------------------------------------------#
    #保存图像需要
    import imageio
    #记录每一类样本的数目,并将其作为文件名
    file_count = [0]*10
    #总共60000个样本,循环保存
    for i in range(0,60000):
        #如果文件名不存在,创建
        if not os.path.exists(str(train_lbl[i])):
            os.makedirs(str(train_lbl[i]))
        #样本数目+1
        file_count[train_lbl[i]] = file_count[train_lbl[i]] + 1
        #获取文件保存的路径
        path = os.path.join(os.path.curdir,str(train_lbl[i])) 
        #生成文件名
        file_name = str(path + "/" + str(file_count[train_lbl[i]]) + ".jpg")
        #保存文件
        imageio.imwrite(file_name,train_img[i])
    

    经过上面的处理,我们可以得到十个文件夹,文件夹的文件名就是图像的label。我已经将打包好的数据上传到了CSDN,有需要的可以点击CSDN链接

    现在使用刚刚提取出来的数据,再将其还原为rec格式。根据后面的操作发现在进行数据转化的时候需要用到cv2模块,所以,首先用pip安装opencv

    pip install opencv-python
    
    

    在得到了原始数据,也安装了opencv module以后,就可以使用mxnet文件夹下/tools/im2re.py程序来完成:

    # 我已经将im2rec.py文件复制到了需要操作的文件夹下,这样方便一点
    python im2rec.py  --recursive=Ture --exts=.jpg --list=True MNIST MNIST
    # --recursive=Ture表示对图像文件夹下的所有文件进行递归操作
    # --exts=.jpg 表示图像文件的后缀为jpg
    # --list=True表示,首先生成一个list文件
    # MNIST表示生成的list文件的前缀,后缀默认为.lst
    # MNIST表示保存图像的文件夹
    

    值得注意的是这个im2re.py 文件存在一个Bug,主要是-exts这个选项这里需要修改一下:

    #Pay attention, different from the source code here
    cgroup.add_argument('--exts', type=list, action='append', default=['.jpeg', '.jpg'],
    #增加action = 'append',这样就可以接受新的参数
    

    如果我们打开生成的MNIST.lst文件,可以得到看到像下面的样子:

    647   0.000000  MNIST/0/1581.jpg
    21679   3.000000    MNIST/3/375.jpg
    39270   6.000000    MNIST/6/3927.jpg
    41692   6.000000    MNIST/6/780.jpg
    23150   3.000000    MNIST/3/5073.jpg
    44133   7.000000    MNIST/7/2978.jpg
    32353   5.000000    MNIST/5/2580.jpg
    
    

    每行最后一项是某图像的位置,倒数第二项是对应的label,也就是该类文件夹的名字,第一项是该文件对于的一个index。得到这list以后,我们可以进一步生成rec文件了。根据im2rec.py文件的源代码,这一过程可以使用多线程也可以使用单线程,但是多线程往往会出错,我的办法是直接将多线程的部分注释掉,直接用单线程处理,修改im2rec.pymain函数部分如下:

    if __name__ == '__main__':
        args = parse_args()
        if args.list:
            make_list(args)
        else:
            if os.path.isdir(args.prefix):
                working_dir = args.prefix
            else:
                working_dir = os.path.dirname(args.prefix)
            files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                        if os.path.isfile(os.path.join(working_dir, fname))]
            count = 0
            for fname in files:
                if fname.startswith(args.prefix) and fname.endswith('.lst'):
                    print('Creating .rec file from', fname, 'in', working_dir)
                    count += 1
                    image_list = read_list(fname)
                    # -- write_record -- #
                    #if args.num_thread > 1 and multiprocessing is not None:
                    #    q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                    #    q_out = multiprocessing.Queue(1024)
                    #    read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \
                    #                    for i in range(args.num_thread)]
                    #    for p in read_process:
                    #        p.start()
                    #    write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                    #    write_process.start()
    
                    #    for i, item in enumerate(image_list):
                    #        q_in[i % len(q_in)].put((i, item))
                    #    for q in q_in:
                    #        q.put(None)
                    #    for p in read_process:
                    #        p.join()
    
                    #    q_out.put(None)
                    #    write_process.join()
                    #else:
                    #print('multiprocessing not available, fall back to single threaded encoding')
                    import Queue
                    q_out = Queue.Queue()
                    fname = os.path.basename(fname)
                    fname_rec = os.path.splitext(fname)[0] + '.rec'
                    fname_idx = os.path.splitext(fname)[0] + '.idx'
                    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                                           os.path.join(working_dir, fname_rec), 'w')
                    cnt = 0
                    pre_time = time.time()
                    for i, item in enumerate(image_list):
                        #打印文件列表
                        print( "current i is " + str(i))
                        print( "current item is " + str(item))
                        image_encode(args, i, item, q_out)
                        if q_out.empty():
                            continue
                        _, s, _ = q_out.get()
                        record.write_idx(item[0], s)
                        if cnt % 1000 == 0:
                            cur_time = time.time()
                            print('time:', cur_time - pre_time, ' count:', cnt)
                            pre_time = cur_time
                        cnt += 1
            if not count:
                print('Did not find and list file with prefix %s'%args.prefix)
    

    现在再次使用im2rec.py文件来生成rec文件:

    python im2rec.py  MNIST.lst  MNIST
    #  MNIST.lst是对应的list文件
    # MNIST 是最后保存rec文件的位置
    

    最后,在当前文件夹下就生成了MNIST.rec文件。
    如果Opencv的版本为3.x的话,可能会出现下面的错误,只要将Opencv的版本退回到2.x版本就可以了

    Segmentation fault (core dumped)
    

    相关文章

      网友评论

          本文标题:MXNET笔记(二)准备数据

          本文链接:https://www.haomeiwen.com/subject/xtalittx.html