TensorFlow offers several ways to read data. The most common is to load the data into memory with plain Python and then feed it to the model, as shown below:
import numpy as np
import tensorflow as tf

# Load data
data = np.load('example/example.npz')
_x, _y = data["_x"], data["_y"]

# Inputs and targets: a placeholder for x of dtype=int32, shape=(None, 9)
x_pl = tf.placeholder(tf.int32, shape=(None, 9))
# Each row of x holds 9 distinct digits from 0-9, so 45 minus their sum
# recovers the digit the row doesn't contain.
y_hat = 45 - tf.reduce_sum(x_pl, axis=1)

# Session
with tf.Session() as sess:
    _y_hat = sess.run(y_hat, {x_pl: _x})
    print("y_hat =", _y_hat[:30])
    print("true y =", _y[:30])
However, when the dataset is large, loading everything into memory takes too much memory and hurts speed. In that case it is better to use the interfaces TensorFlow provides for reading training data.
Using TFRecord
TFRecord files can be copied, moved, read, and stored quickly within TensorFlow. As I understand it, a TFRecord file's contents are a protocol buffer in a format defined by TensorFlow. TensorFlow provides the tf.train.Example interface: the data to be written is packed into an Example, serialized to a string, and then written to a local file via tf.python_io.TFRecordWriter.
1) Serialization
# Serialize
with tf.python_io.TFRecordWriter("example/tfrecord") as fout:
    for _xx, _yy in zip(_x, _y):
        ex = tf.train.Example()
        # Note: the value field must hold a list of ints
        # (_xx is a row of 9 ints; _yy is a scalar, hence append).
        # The map entries are mutated in place, since protobuf does not
        # allow direct assignment to message-valued map fields.
        ex.features.feature['x'].int64_list.value.extend(_xx)
        ex.features.feature['y'].int64_list.value.append(_yy)
        fout.write(ex.SerializeToString())
Alternatively:
example = tf.train.Example(features=tf.train.Features(
    feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))
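Here `i` and `img_raw` come from an enclosing loop that is not shown. A hypothetical version of that loop, just as a sketch (the file names and the PIL-based byte extraction are assumptions, not part of the original):
from PIL import Image
import tensorflow as tf

# Hypothetical context for `i` and `img_raw`: read each image file as raw
# bytes and write one Example per image.
with tf.python_io.TFRecordWriter("example/images.tfrecord") as writer:
    for i, fname in enumerate(["example/image_0.jpg", "example/image_1.jpg"]):
        img_raw = Image.open(fname).tobytes()  # raw pixel bytes
        example = tf.train.Example(features=tf.train.Features(
            feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
        writer.write(example.SerializeToString())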
2) Reading a TFRecord file
This takes three steps:
1) Create a filename queue with tf.train.string_input_producer.
2) Read from that queue with tf.TFRecordReader, which returns a serialized_example object.
3) Parse the Example protocol buffer into tensors with tf.parse_single_example.
The flow for reading a TFRecord file is as follows:
def read_and_decode_single_example(fname):
    # Create a string queue
    fname_q = tf.train.string_input_producer([fname], num_epochs=1, shuffle=True)
    # Create a TFRecordReader
    reader = tf.TFRecordReader()
    # Read the string queue
    _, serialized_example = reader.read(fname_q)
    # Describe the parsing syntax
    features = tf.parse_single_example(
        serialized_example,
        features={'x': tf.FixedLenFeature([9], tf.int64),
                  'y': tf.FixedLenFeature([1], tf.int64)})
    # Output
    x = features['x']
    y = features['y']
    return x, y
# Ops
x, y = read_and_decode_single_example('example/tfrecord')
y_hat = 45 - tf.reduce_sum(x)

# Session
with tf.Session() as sess:
    # Initialize local variables (needed because num_epochs creates one)
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            _y, _y_hat = sess.run([y, y_hat])
            print(_y[0], "==", _y_hat, end="; ")
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()
        # Wait for threads to finish.
        coord.join(threads)
As you can see, parse_single_example decodes the records read from the local file, while start_queue_runners launches the input-pipeline threads, spinning up multiple queue threads that read data and fill the queues. Without it the queues stay empty and the program blocks forever; the QueueRunners must be started to populate them. tf.FixedLenFeature() specifies the shape and dtype of each feature.
tf.train.start_queue_runners(sess=sess, coord=coord)
Starts all queue runners collected in the graph.
This is a companion method to add_queue_runner(). It just starts threads for all queue runners collected in the graph. It returns the list of all threads.
Every thread should check coord.should_stop() before each iteration; once coord.request_stop() has been called, coord.should_stop() returns True. At the end of the program, coord.join(threads) waits for all threads to finish. (A minimal sketch of this pattern follows the excerpt below.)
tf.train.Coordinator()
A coordinator for threads.
This class implements a simple mechanism to coordinate the termination of a set of threads.
Any of the threads can call coord.request_stop() to ask for all the threads to stop.
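As a minimal sketch of the coordination mechanism itself, independent of TensorFlow queues (the worker function below is made up purely for illustration):
import threading
import tensorflow as tf

# A made-up worker that loops until the coordinator asks it to stop.
def worker(coord, worker_id):
    step = 0
    while not coord.should_stop():
        step += 1
        if step >= 5:             # pretend some stop condition is reached
            coord.request_stop()  # any thread may request a global stop

coord = tf.train.Coordinator()
threads = [threading.Thread(target=worker, args=(coord, i)) for i in range(3)]
for t in threads:
    t.start()
coord.join(threads)  # block until all threads have finished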
It is worth noting that if num_epochs=None in tf.train.string_input_producer, the file list is cycled through indefinitely and reading never stops. If num_epochs is set to an integer, a local variable is created to count epochs, which must then be initialized with tf.local_variables_initializer(), as seen in the code above.
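To make the difference concrete, a minimal sketch (the epoch count of 2 is arbitrary):
# num_epochs=None (the default): cycle through the files forever.
fname_q = tf.train.string_input_producer(["example/tfrecord"])

# num_epochs=2: reading stops after two passes; the epoch counter is a
# *local* variable, so it must be initialized before the queue runners run.
fname_q = tf.train.string_input_producer(["example/tfrecord"], num_epochs=2)
init_op = tf.local_variables_initializer()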
tf.train.slice_input_producer is used much like tf.train.string_input_producer, except that it slices a list of tensors directly, producing data for the stages that follow.
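A minimal sketch of slice_input_producer, with toy arrays standing in for real data (the shapes are assumptions):
import numpy as np
import tensorflow as tf

_x = np.arange(20).reshape(10, 2)   # assumed toy features
_y = np.arange(10)                  # assumed toy labels

# slice_input_producer slices along the first dimension of each tensor
# and enqueues one (x, y) pair at a time; no reader op is needed.
x, y = tf.train.slice_input_producer([_x, _y], num_epochs=1, shuffle=True)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print(sess.run([x, y]))
    except tf.errors.OutOfRangeError:
        pass
    finally:
        coord.request_stop()
        coord.join(threads)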
Reading a CSV file:
with open('example/example.csv', 'w') as fout:
    fout.write(_x_str)  # _x_str: the data serialized as CSV text (defined elsewhere)
# Hyperparams
batch_size = 10

# Create a string queue
fname_q = tf.train.string_input_producer(["example/example.csv"])

# Create a TextLineReader
reader = tf.TextLineReader()

# Read the string queue
_, value = reader.read(fname_q)

# Decode value: 10 int columns, each defaulting to 0
record_defaults = [[0]] * 10
col1, col2, col3, col4, col5, col6, col7, col8, col9, col10 = tf.decode_csv(
    value, record_defaults=record_defaults)
x = tf.stack([col1, col2, col3, col4, col5, col6, col7, col8, col9])
y = col10
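The decoded columns describe a single row; before feeding a model they would typically be batched, for example with tf.train.shuffle_batch. A minimal sketch (the capacity numbers are arbitrary):
# Batch single CSV rows into (batch_size, 9) / (batch_size,) tensors.
x_batch, y_batch = tf.train.shuffle_batch(
    [x, y], batch_size=batch_size, capacity=100, min_after_dequeue=10)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    _x_batch, _y_batch = sess.run([x_batch, y_batch])
    coord.request_stop()
    coord.join(threads)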
Reading images:
import matplotlib.pyplot as plt

# Make fake images and save (cast to uint8 so imsave accepts the RGBA data)
for i in range(100):
    _x = np.random.randint(0, 256, size=(10, 10, 4)).astype(np.uint8)
    plt.imsave("example/image_{}.jpg".format(i), _x)

# Hyperparams
batch_size = 10
num_epochs = 1

# Import jpg files
images = tf.train.match_filenames_once('example/*.jpg')

# Create a string queue
fname_q = tf.train.string_input_producer(images, num_epochs=num_epochs, shuffle=True)

# Create a WholeFileReader
reader = tf.WholeFileReader()

# Read the string queue
_, value = reader.read(fname_q)

# Decode value
img = tf.image.decode_image(value, channels=4)

# Batching
img_batch = tf.train.batch([img], shapes=([10, 10, 4]), batch_size=batch_size)
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    num_samples = 0
    try:
        while not coord.should_stop():
            sess.run(img_batch)
            num_samples += batch_size
            print(num_samples, "samples have been seen")
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()
        coord.join(threads)
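One detail worth noting: tf.image.decode_image handles several image formats, so its output has no static shape, which is why shapes= is passed to tf.train.batch above. If the decoded size is known in advance, a hedged alternative (assuming every file decodes to a 10x10 RGBA image) is to fix the static shape yourself:
# Fixing the static shape lets tf.train.batch infer shapes on its own,
# so no explicit shapes= argument is needed.
img = tf.image.decode_image(value, channels=4)
img.set_shape([10, 10, 4])
img_batch = tf.train.batch([img], batch_size=batch_size)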