tfrecords是tensorflow自带的文件格式,也是一种二进制文件:
- 方便读取和移动
- 是为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中
- 文件格式:*.tfrecords
- 写如文件的内容:Example协议块,是一种类字典的格式
TFRecords存储的api
"""api
1、建立TFRecord存储器
tf.python_io.TFRecordWriter(path)
写入tfrecords文件
path: TFRecords文件的路径
return:写文件
method
write(record):向文件中写入一个字符串记录
close():关闭文件写入器
注:字符串为一个序列化的Example,Example.SerializeToString()
2、构造每个样本的Example协议块
tf.train.Example(features=None)
写入tfrecords文件
features:tf.train.Features类型的特征实例
return:example格式协议块
tf.train.Features(feature=None)
构建每个样本的信息键值对
feature:字典数据,key为要保存的名字,
value为tf.train.Feature实例
return:Features类型
tf.train.Feature(**options)
**options:例如
bytes_list=tf.train. BytesList(value=[Bytes])
int64_list=tf.train. Int64List(value=[Value])
tf.train. Int64List(value=[Value])
tf.train. BytesList(value=[Bytes])
tf.train. FloatList(value=[value])
同文件阅读器流程,中间需要解析过程
解析TFRecords的example协议内存块
tf.parse_single_example(serialized,features=None,name=None)
解析一个单一的Example原型
serialized:标量字符串Tensor,一个序列化的Example
features:dict字典数据,键为读取的名字,值为FixedLenFeature
return:一个键值对组成的字典,键为读取的名字
tf.FixedLenFeature(shape,dtype)
shape:输入数据的形状,一般不指定,为空列表
dtype:输入数据类型,与存储进文件的类型要一致
类型只能是float32,int64,string
"""
读取tfrecords的api与流程
"""api
同文件阅读器流程,中间需要解析过程
解析TFRecords的example协议内存块
tf.parse_single_example(serialized,features=None,name=None)
解析一个单一的Example原型
serialized:标量字符串Tensor,一个序列化的Example
features:dict字典数据,键为读取的名字,值为FixedLenFeature
return:一个键值对组成的字典,键为读取的名字
"""
"""流程
tf.FixedLenFeature(shape,dtype)
shape:输入数据的形状,一般不指定,为空列表
dtype:输入数据类型,与存储进文件的类型要一致
类型只能是float32,int64,string
1、构造TFRecords阅读器
2、解析Example
3、转换格式,bytes解码
"""
tfrecords文件的读取
import tensorflow as tf
# 定义cifar的数据等命令行参数
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("cifar_dir", "cifar-10-batches-py", "文件的目录")
tf.app.flags.DEFINE_string("cifar_tfrecords", "./tmp/cifar.tfrecords", "存进tfrecords的文件")
class CifarRead():
"""
完成读取二进制文件,写进tfrecords,读取tfrecords
"""
def __init__(self, filelist):
self.file_list = filelist # 文件列表
# 定义读取图片的一些属性
self.height = 32
self.width = 32
self.channel = 3
# 存储的字节
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channel
self.bytes = self.label_bytes + self.image_bytes
def read_and_decode(self):
# 构造文件队列
file_queue = tf.train.string_input_producer(self.file_list)
# 构造二进制文件读取器,并指定读取长度
reader = tf.FixedLengthRecordReader(self.bytes)
key, value = reader.read(file_queue)
# 解码内容
print(value)
# 二进制文件的解码
label_image = tf.decode_raw(value, out_type=tf.uint8)
print(label_image)
# 分割图片和标签:特征值和目标值
label = tf.slice(label_image, [0], [self.label_bytes]) #读取标签
image = tf.slice(label_image, [self.label_bytes], [self.image_bytes]) #读取特征向量
print("label:", label)
print("image:", image)
# 对图片的特征数据进行形状的改变 [3072] --> [32, 32, 3]
image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
print("image_reshape:", image_reshape)
# 批处理数据
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
print(image_batch, label_batch)
return image_batch, label_batch
def write_to_tfrecords(self, image_batch, label_batch):
"""
将图片的特征值和目标值存进tfrecords
:param image_batch: 10张图片的特征值
:param label_batch: 10张图片的目标值
:return: None
"""
#建立一个tfrecords存储器
writer = tf.python_io.TFRecordWriter(path=FLAGS.cifar_tfrecords) #注意:tf.python_io.TFRecordWriter已经被tf.io.TFRecordWriter代替
#循环将所有样本写入文件,每张图片样本都要构造一个example协议
for i in range(10):
#取出第i个图片数据的特征值和目标值
image = image_batch[i].eval().tostring() #.eval()获取值
label = label_batch[i].eval()[0] #因为是一个二维列表,所以必须取[0]
"""注意:eval必须写在session中"""
#构造一个样本的example
example = tf.train.Example(features=tf.train.Features(feature={
"image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
#写入单独的样本
writer.write(example.SerializeToString()) #序列化后再写入文件
#关闭
writer.close()
def read_from_tfrecords(self):
#构造文件阅读器
file_queue = tf.train.input_producer([FLAGS.cifar_tfrecords])
#构造文件阅读器,读取内容example
reader = tf.TFRecordReader()
key, value = reader.read(file_queue) #value也是一个example的序列化
#由于存储的是example,所以需要对example解析
features = tf.parse_single_example(value, features={
"image":tf.FixedLenFeature(shape=[], dtype=tf.string),
"label":tf.FixedLenFeature(shape=[], dtype=tf.int64)
})
print(features["image"], features["label"]) #注意:此时是tensor
#解码内容,如果读取的string类型,需要解码,如果是int64,float32就不需要解码。因为里面都是bytes,所以需要解码
image = tf.decode_raw(features["image"], tf.uint8)
label = tf.cast(features["label"], tf.int32) #label不需要解码,因为int64实际在存储的时候还是以int32存储的,不会占用那么多空间,所以这里可以直接转换成int32
print(image, label)
#固定图片的形状,以方便批处理
image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
print(image_reshape)
#进行批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
return image_batch, label_batch
import os
if __name__ == "__main__":
# 找到文件,放入列表 路径+名字 ->列表当中
file_name = os.listdir(FLAGS.cifar_dir)
file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if "0" <= file[-1] <= "9"]
print(file_list)
cf = CifarRead(file_list)
image_batch, label_batch = cf.read_and_decode()
# image_batch, label_batch = cf.read_from_tfrecords()
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord=coord)
print(sess.run([image_batch, label_batch]))
#存进tfrecords文件
print("开始存储...")
cf.write_to_tfrecords(image_batch, label_batch) #因为这个函数里面有eval,所以必须在session里面运行
print("结束存储...")
# print("读取的数据:\n",sess.run([image_batch, label_batch]))
coord.request_stop()
coord.join()
网友评论