美文网首页
在TensorFlow中使用pipeline加载数据

在TensorFlow中使用pipeline加载数据

作者: hzyido | 来源:发表于2017-08-06 15:29 被阅读445次

前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下图所示:


首先,A、B、C三个文件通过RandomShuffle进程被随机加载到FilenameQueue里,然后Reader1和Reader2进程同FilenameQueue里取文件名读取文件,读取的内容再被放到ExampleQueue里。最后,计算进程会从ExampleQueue里取数据。各个进程独立操作,互不影响,这样可以加快程序速度。
我们简单地生成3个样本文件。

#生成三个样本文件,每个文件包含5列,假设前4列为特征,最后1列为标签
data = np.zeros([20,5])
np.savetxt('file0.csv', data, fmt='%d', delimiter=',')
data += 1
np.savetxt('file1.csv', data, fmt='%d', delimiter=',')
data += 1
np.savetxt('file2.csv', data, fmt='%d', delimiter=',')

然后,创建pipeline数据流。

#定义FilenameQueue
filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)])
#定义ExampleQueue
example_queue = tf.RandomShuffleQueue(
    capacity=1000,
    min_after_dequeue=0,
    dtypes=[tf.int32,tf.int32],
    shapes=[[4],[1]]
)
#读取CSV文件,每次读一行
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
#对一行数据进行解码
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])
#将特征和标签push进ExampleQueue
enq_op = example_queue.enqueue([features, [col5]])
#使用QueueRunner创建两个进程加载数据到ExampleQueue
qr = tf.train.QueueRunner(example_queue, [enq_op]*2)
#使用此方法方便后面tf.train.start_queue_runner统一开始进程
tf.train.add_queue_runner(qr)
xs = example_queue.dequeue()
with tf.Session() as sess:
    coord = tf.train.Coordinator()
#开始所有进程
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(200):
        x = sess.run(xs)
        print(x)
    coord.request_stop()
    coord.join(threads)

以上我们采用for循环step_num次来控制训练迭代次数。我们也可以通过tf.train.string_input_producer的num_epochs参数来设置FilenameQueue循环次数来控制训练,当达到num_epochs时,TensorFlow会抛出OutOfRangeError异常,通过捕获该异常,停止训练。

filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)], num_epochs=6)
...
with tf.Session() as sess:
    sess.run(tf.initialize_local_variables()) #必须加上这句话,否则报错!
    coord = tf.train.Coordinator()
#开始所有进程
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            x = sess.run(xs)
            print(x)
    except tf.errors.OutOfRangeError:
        print('Done training -- epch limit reached')
    finally:
        coord.request_stop()

捕获到异常时,请求结束所有进程。

原文: 在TensorFlow中使用pipeline加载数据

相关文章

  • 在TensorFlow中使用pipeline加载数据

    前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下...

  • 1. Tensorflow实战学习:TFRecord读取数据

    数据读取的基本方式参见CS 20SI 的Input Pipeline部分,Tensorflow主要有两种加载数据的...

  • tf_record 的几种生成方法。

    背景:大家在使用tensorflow 训练model 的时候,如何更好更快的加载数据,tensorflow官方给出...

  • [Tensorflow2] 数据加载

    针对小型常用数据集,tensorflow2中加载数据通常有两种方法:1、使用keras.datasets 有几种数...

  • 3·深入MNIST

    复习上一小节 加载MNIST数据 运行TensorFlow的InteractiveSession 这里,我们使用更...

  • tensorflow读取数据

    tensorflow有几种读取数据的方式,最常见的使用python普通加载,加载进内存,再传给模型。如下所示: 但...

  • tensorflow教程2:数据读取

    Tensorflow的数据读取有三种方式: Preloaded data: 预加载数据,也就是TensorFlow...

  • Apache Beam 处理文件

    今天我们介绍了如何使用pipeline在 Apache Beam 中的文件中读取、写入数据,其中“Employee...

  • ChAMP包学习(2)

    ChAMP Pipeline 1. 加载数据 加载数据始终是第一步。ChAMP提供了一个加载函数,用于从.idat...

  • Tensorflow初识

    1. 初识tensorflow tensorflow中需要明白的几点:使用tensor表示数据使用图来表示计算任务...

网友评论

      本文标题:在TensorFlow中使用pipeline加载数据

      本文链接:https://www.haomeiwen.com/subject/ajzglxtx.html