一 : 选择tfrecord的好处
对于数据量大的时候可以采用多线程处理数据,例如可以一个线程处理数据,另一个做训练
tfrecord可以分成三部分:encode、decode、run batch,encode批量将数据/图片存储成字典形式,decode和tf.train.shuffle_batch或tf.train.batch配合这里注意每次取出的是一个数据的tensor,通过epoch和sess.run的迭代过程来完成数据的批量处理。
二:encode 参考代码如下
recordfilenum = 0
img_raw = img.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))
writer.write(example.SerializeToString())
if count % mumber ==0 and count!=0:
writer.close()
recordfilenum = recordfilenum +1
tfrecordfilename = ("traindata.tfrecords-%.2d" % recordfilenum)
writer = tf.python_io.TFRecordWriter( path+tfrecordfilename )
[注:这里是 mumber个数据写入到一个tfrecord文件中。path为存储路径。
还要注意不同数据处理方式有所不同(参考:http://zhangzhenyuan.lofter.com/post/e1458_10a6f295)。大文件可以选择多个tfrecord文件分批处理数据]
例如对于np数组
c = np.array([[0, 1, 2],[3, 4, 5]])
c = c.astype(np.uint8)
c_raw = c.tostring()
解码时
c_out = tf.decode_raw(c_raw_out, tf.uint8)
c_out = tf.reshape(c_out, [2, 3])【其中example也要注意改成对应byteslist形式】
三:decode参考代码如下
filename_queue = '/存储路径/traindata.tfrecords-*'
files = tf.train.match_filenames_once(filename_queue)
filename_qu = tf.train.string_input_producer( files,shuffle = False,num_epochs = 3)
reader = tf.TFRecordReader()
_, serialized_example = reader.read( filename_qu)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['img_raw'],tf.float64)
image = tf.reshape(image, [224,224,3])
batch_image,batch_imi , batch_label = tf.train.shuffle_batch([image,imis ,labels],
batch_size=batch_size,
num_threads=2, capacity=1000 + 3 *
batch_size,
min_after_dequeue=1000)
【注:很多np数组是解码成uint8(bytes)格式的,但如果你数据很小,比如我的数据就做过归一化处理,这样接近零的数据一解码就变成了零,会发生数据丢失现象,所以这里直接解码成float格式,这里要注意。
最后返回的是batch个tensor数据,可以在loop过程中通过sess.run()直接运行生成结果;还需要注意的是在tf.train.shuffle_batch前一定要确定数据形状,不然可能出现bug】
网友评论