加载数据
TensorFlow 作为符号编程框架,需要先构建数据流图,再读取数据,随后进行模型训练。
- 预加载数据(preloaded data):在 TensorFlow 图中定义常量或变量来保存所有数据。这种方式的缺点在于,将数据直接嵌在数据流图中,当训练数据较大时,很消耗内存。
- 填充数据(feeding): 使用 sess.run()中的 feed_dict 参数,将 Python 产生的数据填充给后端。Python 产生数据,再把数据填充后端。填充的方式也有数据量大、消耗内存等缺点。
- 从文件读取数据(reading from file):从文件中直接读取,让队列管理器从文件中读取数据。这是最推荐的方式,让 TensorFlow 自己从文件中读取数据,并解码成可使用的样本集。
import tensorflow as tf
# 第二种方式:填充数据
a1 = tf.placeholder(tf.int16)
a2 = tf.placeholder(tf.int16)
b = tf.add(x1, x2)
# 用 Python 产生数据
li1 = [2, 3, 4]
li2 = [4, 0, 1]
# 打开一个会话,将数据填充给后端
with tf.Session() as sess:
print sess.run(b, feed_dict={a1: li1, a2: li2})
TFRecords 是一种二进制文件,能更好地利用内存,更方便地复制和移动,并且不需要单独的标记文件。
从文件读取数据分为如下两个步骤:
(1)把样本数据写入 TFRecords 二进制文件;
(2)再从队列中读取。
把样本数据写入 TFRecords 二进制文件
- 将数据填入到 tf.train.Example 的协议缓冲区(protocolbuffer)中
example=tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[i].tolist)),
'image_raw': _bytes_feature(image_raw)
}))
- 将协议缓冲区序列化为一个字符串,通过 tf.python_io.TFRecordWriter 写入 TFRecords文件
#定义一个writer
filename=os.path.join(os.getcwd(),name+'.tfrecords')
writer= tf.python_io.TFRecordWriter(filename)
......
#对于for i in range(num_example)中的每个example,写入文件
writer.write(example.SerializerToString())
- 最后关闭writer
writer.close()
从队列中读取
一旦生成了 TFRecords 文件,接下来就可以使用队列读取数据了。主要分为 3 步:
(1)创建张量,从二进制文件读取一个样本;
(2)创建张量,从二进制文件随机读取一个 mini-batch;
(3)把每一批张量传入网络作为输入节点。
网友评论