美文网首页深度学习
tfrecords读取文件

tfrecords读取文件

作者: Kevin__ding | 来源:发表于2018-10-17 19:19 被阅读0次

    以前做科研论文的时候, 所使用的音频数据比较少, 所以都是直接读进内存中在feeding给placeholder。现在在做一些偏工程的项目时就发现远远不行了,feeding训练速度远远提不上来。所以这两周都在为训练提速而折磨。在此记录下来尝试的方式。
    tensorflow推荐使用tfrecords来存储数据, 这样能加快数据的读取。

    def convert_to_tfrecord(loader):
        ''' modefy batch_size=1 in './conf/train_ce.conf' before convert to tfrecord data format '''
        def write_tfrecords(queue, i):
            start_time = time.time()
            while queue.empty():
                if time.time()-start_time > 600:               #超时队列中还没有数据该进程就退出
                    print('wait timeout! proc %d exit!'%i)
                    exit()
                time.sleep(1)
            writer = tf.python_io.TFRecordWriter('./train_input/tfrecords_file/train_dataset_%d.tfrecords'%i)
            while queue.qsize():
                batch = queue.get()                            # 从队列中获取一个样本
                example = tf.train.Example(features=tf.train.Features(feature={
                    'feature':  tf.train.Feature(float_list=tf.train.FloatList(value=batch[0].flatten())),
                    'label':    tf.train.Feature(int64_list=tf.train.Int64List(value=batch[1].flatten())),
                    'mask':     tf.train.Feature(int64_list=tf.train.Int64List(value=batch[2].flatten())),
                    'length':   tf.train.Feature(int64_list=tf.train.Int64List(value=[batch[3][0][0]]))
                    #'feature_shape': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(batch[0].shape)))
                }))                                            # 这里将二维的feature label mask 转为一维的进行存储
                writer.write(example.SerializeToString())
            writer.close()
    
        start = time.time()
        queue = Queue(512)
        proc_record = []
        for i in range(10):
            p = Process(target=write_tfrecords, args=(queue, i)) #开10个进程用来写入数据
            p.start()
            proc_record.append(p)
        num = 0
        while True:
            try:
                batch = loader.next()                            # 获取一个样本, 压入队列
            except StopIteration:
                tf.logging.info('finished convert to tfrecords')
                break
            if batch is not None:
                queue.put(batch)
                num += 1
            else:
                break
        for p in proc_record:   p.join()                        # 等待所有进程结束
        print('num:', num)
        print('time:', time.time()-start)
    

    程序写了一个多进程写入tfrecords, 在主进程中读取数据压入队列,再开辟10个进程从队列中读取数据, 因为我的loader.next加载数据比较长,所以在子进程中设置了循环等待。

    在尝试过多线程, 应为python GIL的原因, 所以速度没有提升, 改成了多进程。

    相关文章

      网友评论

        本文标题:tfrecords读取文件

        本文链接:https://www.haomeiwen.com/subject/sfjkzftx.html