写
-
思路
数据移到example, example序列化为字符串后,再写入文件。example
包含Features
, Features
包含Feature字典
,Feature字典
的value
中里包含有一个 FloatList
, 或者ByteList
,或者Int64List
, 数据就写在参数value=[]
的列表里。
1. 定义writer
`writer = tf.python_io.TFRecordWriter(tfrecord路径)`
2. 定义example(相当搞人!!)
```
# img转变为bytes
img = Image.open(img_path)
img = img.resize((224, 224))
img_raw = img.tobytes()
example=tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
"img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
```
3. 序列化为字符串后写入tfrecord
writer.write(example.SerializeToString())
读
-
队列读取
简明示意图:
-
计算图
1. 文件名队列
filename_queue = tf.train.string_input_producer([filename])
2. 构造Reader
reader = tf.TFRecordReader()
3. 从队列中读序列化数据
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
4. 解析单个example的features
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
5. 从features 字典拿到可用数据
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [224, 224, 3])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
label = tf.cast(features['label'], tf.int32)
6. 获得batch
img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=30, capacity=2000, min_after_dequeue=1000)
-
运行
打开sess
后,
1. 首先初始化graph
sess.run(tf.initialize_all_variables())
1. 启动队列
threads = tf.train.start_queue_runners(sess=sess)
2. sess.run()
img_batch, label_batch = sess.run([img_batch, label_batch]
网友评论