美文网首页
29-Tfrecords文件的读取与存储

29-Tfrecords文件的读取与存储

作者: jxvl假装 | 来源:发表于2019-10-05 11:17 被阅读0次

    tfrecords是tensorflow自带的文件格式,也是一种二进制文件:

    1. 方便读取和移动
    2. 是为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中
    3. 文件格式:*.tfrecords
    4. 写如文件的内容:Example协议块,是一种类字典的格式

    TFRecords存储的api

    """api
    1、建立TFRecord存储器
    tf.python_io.TFRecordWriter(path)
        写入tfrecords文件
        path: TFRecords文件的路径
        return:写文件
    method
        write(record):向文件中写入一个字符串记录
        close():关闭文件写入器
    注:字符串为一个序列化的Example,Example.SerializeToString()
    
    2、构造每个样本的Example协议块
    tf.train.Example(features=None)
        写入tfrecords文件
        features:tf.train.Features类型的特征实例
        return:example格式协议块
    
    tf.train.Features(feature=None)
        构建每个样本的信息键值对
        feature:字典数据,key为要保存的名字,
        value为tf.train.Feature实例
        return:Features类型
    
    tf.train.Feature(**options)
        **options:例如
        bytes_list=tf.train. BytesList(value=[Bytes])
        int64_list=tf.train. Int64List(value=[Value])
        tf.train. Int64List(value=[Value])
        tf.train. BytesList(value=[Bytes])
        tf.train. FloatList(value=[value])
    
    同文件阅读器流程,中间需要解析过程
    
    解析TFRecords的example协议内存块
    tf.parse_single_example(serialized,features=None,name=None)
        解析一个单一的Example原型
        serialized:标量字符串Tensor,一个序列化的Example
        features:dict字典数据,键为读取的名字,值为FixedLenFeature
        return:一个键值对组成的字典,键为读取的名字
    
    tf.FixedLenFeature(shape,dtype)
        shape:输入数据的形状,一般不指定,为空列表
        dtype:输入数据类型,与存储进文件的类型要一致
        类型只能是float32,int64,string
    
    """
    

    读取tfrecords的api与流程

    """api
    同文件阅读器流程,中间需要解析过程
    
    解析TFRecords的example协议内存块
    tf.parse_single_example(serialized,features=None,name=None)
        解析一个单一的Example原型
        serialized:标量字符串Tensor,一个序列化的Example
        features:dict字典数据,键为读取的名字,值为FixedLenFeature
        return:一个键值对组成的字典,键为读取的名字
    """
    """流程
    tf.FixedLenFeature(shape,dtype)
        shape:输入数据的形状,一般不指定,为空列表
        dtype:输入数据类型,与存储进文件的类型要一致
        类型只能是float32,int64,string
    
    
    1、构造TFRecords阅读器
    2、解析Example
    3、转换格式,bytes解码
    """
    

    tfrecords文件的读取

    import tensorflow as tf
    
    # 定义cifar的数据等命令行参数
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string("cifar_dir", "cifar-10-batches-py", "文件的目录")
    tf.app.flags.DEFINE_string("cifar_tfrecords", "./tmp/cifar.tfrecords", "存进tfrecords的文件")
    
    class CifarRead():
        """
        完成读取二进制文件,写进tfrecords,读取tfrecords
        """
    
        def __init__(self, filelist):
            self.file_list = filelist  # 文件列表
            # 定义读取图片的一些属性
            self.height = 32
            self.width = 32
            self.channel = 3
            # 存储的字节
            self.label_bytes = 1
            self.image_bytes = self.height * self.width * self.channel
            self.bytes = self.label_bytes + self.image_bytes
    
        def read_and_decode(self):
            # 构造文件队列
            file_queue = tf.train.string_input_producer(self.file_list)
            # 构造二进制文件读取器,并指定读取长度
            reader = tf.FixedLengthRecordReader(self.bytes)
            key, value = reader.read(file_queue)
            # 解码内容
            print(value)
            # 二进制文件的解码
            label_image = tf.decode_raw(value, out_type=tf.uint8)
            print(label_image)
            # 分割图片和标签:特征值和目标值
            label = tf.slice(label_image, [0], [self.label_bytes])  #读取标签
            image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])  #读取特征向量
            print("label:", label)
            print("image:", image)
            # 对图片的特征数据进行形状的改变 [3072] --> [32, 32, 3]
            image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
            print("image_reshape:", image_reshape)
    
            # 批处理数据
            image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
            print(image_batch, label_batch)
            return image_batch, label_batch
    
        def write_to_tfrecords(self, image_batch, label_batch):
            """
            将图片的特征值和目标值存进tfrecords
            :param image_batch: 10张图片的特征值
            :param label_batch: 10张图片的目标值
            :return: None
            """
            #建立一个tfrecords存储器
            writer = tf.python_io.TFRecordWriter(path=FLAGS.cifar_tfrecords)  #注意:tf.python_io.TFRecordWriter已经被tf.io.TFRecordWriter代替
            #循环将所有样本写入文件,每张图片样本都要构造一个example协议
            for i in range(10):
                #取出第i个图片数据的特征值和目标值
                image = image_batch[i].eval().tostring()   #.eval()获取值
                label = label_batch[i].eval()[0]    #因为是一个二维列表,所以必须取[0]
                """注意:eval必须写在session中"""
                #构造一个样本的example
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
                }))
                #写入单独的样本
                writer.write(example.SerializeToString())   #序列化后再写入文件
            #关闭
            writer.close()
    
        def read_from_tfrecords(self):
            #构造文件阅读器
            file_queue = tf.train.input_producer([FLAGS.cifar_tfrecords])
            #构造文件阅读器,读取内容example
            reader = tf.TFRecordReader()
            key, value = reader.read(file_queue)  #value也是一个example的序列化
            #由于存储的是example,所以需要对example解析
            features = tf.parse_single_example(value, features={
                "image":tf.FixedLenFeature(shape=[], dtype=tf.string),
                "label":tf.FixedLenFeature(shape=[], dtype=tf.int64)
            })
            print(features["image"], features["label"]) #注意:此时是tensor
            #解码内容,如果读取的string类型,需要解码,如果是int64,float32就不需要解码。因为里面都是bytes,所以需要解码
            image = tf.decode_raw(features["image"], tf.uint8)
            label = tf.cast(features["label"], tf.int32)   #label不需要解码,因为int64实际在存储的时候还是以int32存储的,不会占用那么多空间,所以这里可以直接转换成int32
            print(image, label)
            #固定图片的形状,以方便批处理
            image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
            print(image_reshape)
            #进行批处理
            image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
            return image_batch, label_batch
    import os
    
    if __name__ == "__main__":
        # 找到文件,放入列表 路径+名字 ->列表当中
        file_name = os.listdir(FLAGS.cifar_dir)
        file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if "0" <= file[-1] <= "9"]
        print(file_list)
        cf = CifarRead(file_list)
        image_batch, label_batch = cf.read_and_decode()
        # image_batch, label_batch = cf.read_from_tfrecords()
        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord=coord)
            print(sess.run([image_batch, label_batch]))
            #存进tfrecords文件
            print("开始存储...")
            cf.write_to_tfrecords(image_batch, label_batch) #因为这个函数里面有eval,所以必须在session里面运行
            print("结束存储...")
            # print("读取的数据:\n",sess.run([image_batch, label_batch]))
            coord.request_stop()
            coord.join()
    
    

    相关文章

      网友评论

          本文标题:29-Tfrecords文件的读取与存储

          本文链接:https://www.haomeiwen.com/subject/ltnbuctx.html