TensorFlow的数据读取机制:(读取图片为例)
如果将图片先读取到内存中后提供给GPU或CPU计算,这样GPU在数据读取的时间是无事可做的,这大大降低运算效率。为此将读取数据和计算分别放在两个线程中去做。一个线程负责源源不断的将图片读取到内存的一个队列中,另一个线程直接从队列中取用计算。
机器学习通常使用epoch来重复计算,运行一个epoch就是将数据集中的所有图片都计算一遍,两个epoch计算两遍。为了方便管理epoch,TensorFlow在内存队列前又添加了一个队列叫做‘文件名队列’。使用 tf.train.string_input_producer函数来创建文件名队列,需要传入一个文件名 list,系统会自动将它转换成一个文件名队列。此外还需要num_epoch(epoch的 数目)和shuffle(在文件名队列中打乱图片顺序)两个参数。在TensorFlow中不需要自己创建内存队列,使用reader对象从文件名中读取数据即可,如下所示。
reader = tf.WholeFileReader key, value = reader.read(filename_queue)#filename_queue是文件名队列。
一切准备就绪之后,咱们定义的文件名队列中并没有开始读入图片,需要使用tf.train.star_queue_runners函数进行激活。如下是读取三张图片的程序:
with tf.Session() as sess:
filename = ['A.jpg', 'B.jpg', 'C.jpg']
filename_queue = tf.train.string_input_producer(filename, huffle=false,num_epoch=5)
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
tf.local_variables_initializer().run()
threads = tf.train.start_queue_runners(sess=sess)
i = 0
while True :
i+=1
image_data = sess.run(value)
with open('read/test_%d.jpg'%i, 'wb') as f:
f.write(image_data)
网友评论