美文网首页
TensorFlow自学第4篇——文件读写

TensorFlow自学第4篇——文件读写

作者: 锅底一盆面 | 来源:发表于2019-01-18 22:38 被阅读0次

    这是一沟绝望的死水,清风吹不起半点漪沦。
    不如多扔些破铜烂铁,爽性泼你的剩菜残羹。
    也许铜的要绿成翡翠,铁罐上锈出几瓣桃花;
    再让油腻织一层罗绮,霉菌给他蒸出些云霞。
    让死水酵成一沟绿酒,漂满了珍珠似的白沫;
    小珠们笑声变成大珠,又被偷酒的花蚊咬破。
    那么一沟绝望的死水,也就夸得上几分鲜明。
    如果青蛙耐不住寂寞,又算死水叫出了歌声。
    这是一沟绝望的死水,这里断不是美的所在。
    不如让给丑恶来开垦,看它造出个什么世界。
    ——朱自清《死水》

    昨天没有更新,并不是因为懒惰,只是随着渐渐深入,所学略感吃力,所以干脆把文件读写一并总结。好像不管csv、图像、二进制、tfrecords,大致都是“构造文件队列—构造阅读器—读取—解码—处理—批处理”等一套流程。文件读取以及读取数据处理成张量的过程如下所示。


    图像

    在图像数字化表示当中,分为黑白和彩色两种。在数字化表示图片的时候,有三个因素。分别是图片的长、图片的宽、图片的颜色通道数。那么黑白图片的颜色通道数为1,它只需要一个数字就可以表示一个像素位;而彩色照片就不一样了,它有三个颜色通道,分别为RGB,通过三个数字表示一个像素位。TensorFlow支持JPG、PNG图像格式,RGB、RGBA颜色空间。图像用与图像尺寸相同(heightwidthchnanel)张量表示。图像所有像素存在磁盘文件,需要被加载到内存。

    大尺寸图像输入占用大量系统内存。训练CNN需要大量时间,加载大文件增加更多训练时间,也难存放多数系统GPU显存。大尺寸图像大量无关本征属性信息,影响模型泛化能力。最好在预处理阶段完成图像操作,缩小、裁剪、缩放、灰度调整等。图像加载后,翻转、扭曲,使输入网络训练信息多样化,缓解过拟合。Python图像处理框架PIL、OpenCV。TensorFlow提供部分图像处理方法。

    tf.image.resize_images //压缩图片导致定大小
    

    同样图像加载与二进制文件相同。图像需要解码。输入生成器(tf.train.string_input_producer)找到所需文件,加载到队列。tf.WholeFileReader 加载完整图像文件到内存,WholeFileReader.read 读取图像,tf.image.decode_jpeg 解码JPEG格式图像。图像是三阶张量。RGB值是一阶张量。加载图像格 式为[batch_size,image_height,image_width,channels]。批数据图像过大过多,占用内存过高,系统会停止响应。直接加载TFRecord文件,可以节省训练时间。支持写入多个样本。

    对图像而言,输入是像素,即特征值(一个像素是一个特征值)。单通道的一个像素点是一个特征值,三通道的像素点是三个特征值。图像数字化三要素:长度、宽度、通道数——[height, width, channel]。训练时,必须保证每个样本的特征值数量一致,即所有图片的像素值一样;尽量减少数据量。图片存储用uint8格式,节约空间,图片矩阵计算用float32格式,提高精度。

    二进制文件

    二进制文件数据集“CIFAR-10”中一行的格式:1字节标签+1024字节R+1024字节G+1024字节B,每个文件包含10000个这样3073个字节的行图像。

    标准TensorFlow格式

    TensorFlow提供了一种内置文件格式TFRecord,二进制数据和训练类别标签数据存储在同一文件。模型训练前图像等文本信息转换为TFRecord格式。TFRecord文件是protobuf格式。数据不压缩,可快速加载到内存。TFRecords文件包含 tf.train.Example protobuf,需要将Example填充到协议缓冲区,将协议缓冲区序列化为字符串,然后使用该文件将该字符串写入TFRecords文件。
    TensorFlow自带的文件格式,方便读取和移动;Example协议块,类字典格式
    写入文件时,对每一个样本都要构造Example协议块。

    #文件队列生成函数
    tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, name=None) //产生指定文件张量
    #文件阅读器类
    tf.TextLineReader //阅读文本文件逗号分隔值(CSV)格式
    tf.FixedLengthRecordReader //要读取每个记录是固定数量字节的二进制文件
    tf.TFRecordReader //读取TfRecords文件
    # 解码:由于从文件中读取的是字符串,需要函数去解析这些字符串到张量
    tf.decode_csv(records,record_defaults,field_delim = None,name = None)//将CSV转换为张量,与tf.TextLineReader搭配使用
    tf.decode_raw(bytes,out_type,little_endian = None,name = None) //将字节转换为一个数字向量表示,字节为一字符串类型的张量,与函数tf.FixedLengthRecordReader搭配使用
    

    生成文件队列

    将文件名列表交给tf.train.string_input_producer函数。string_input_producer来生成一个先入先出的队列,文件阅读器会需要它们来取数据。string_input_producer提供的可配置参数来设置文件名乱序和最大的训练迭代数,QueueRunner会为每次迭代(epoch)将所有的文件名加入文件名队列中,如果shuffle=True的话,会对文件名进行乱序处理。一过程是比较均匀的,因此它可以产生均衡的文件名队列。

    这个QueueRunner工作线程是独立于文件阅读器的线程,因此乱序和将文件名推入到文件名队列这些过程不会阻塞文件阅读器运行。根据你的文件格式,选择对应的文件阅读器,然后将文件名队列提供给阅读器的 read 方法。阅读器的read方法会输出一个键来表征输入的文件和其中纪录(对于调试非常有用),同时得到一个字符串标量,这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。

    每次read的执行都会从文件中读取一行内容,注意,(这与图片和TfRecords读取不一样),decode_csv操作会解析这一行内容并将其转为张量列表。如果输入的参数有缺失,record_default参数可以根据张量的类型来设置默认值。在调用run或者eval去执行read之前,你必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。

    举个栗子

    import tensorflow as tf
    import os
    # 忽略不必要的警告信息
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    """
    图片读取流程和二进制文件读取流程的开关
    """
    PictureReader = 0
    BinaryReader = 0
    TfrecordsReader = 1
    
    class CifarRead(object):
        """
        完成读取二进制文件,写进tfrecords,读取tfrecords
        """
        def __init__(self, filelist):
            # 文件列表
            self.filelist = filelist
            # 定义图片属性
            self.height = 32
            self.width = 32
            self.channel = 3
            # 二进制文件每张图片的字节
            self.label_bytes = 1
            self.image_bytes = self.height * self.width * self.channel
            self.bytes = self.image_bytes + self.label_bytes
    
        def read_and_decode(self):
            # 1、构造文件队列
            file_queue = tf.train.string_input_producer(self.filelist)
            # 2、构造二进制文件读取器,读取内容,每个样本的字节数
            reader = tf.FixedLengthRecordReader(self.bytes)
            key, value = reader.read(file_queue)
            # 3、对图片数据解码
            label_image = tf.decode_raw(value, tf.uint8)
            print(label_image)
            # 4、分割出图片值和标签值,及特征和目标值
            label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
            image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
            print(label, image)
            # 5、对图片的特征数据进行改变 【3072->32×32×3】
            image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
            print(label, image_reshape)
            # 6、批处理数据
            image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
            print(image_batch, label_batch)
    
            return image_batch, label_batch
    
        def write_to_tfrecords(self, image_batch, label_batch):
            """
            将图片的特征值和目标值存进tfrecords
            :param image_batch: 10张图片的特征值
            :param label_batch: 10张图片的目标值
            :return: None
            """
            # 1、建立TFRecords存储器
            writer = tf.python_io.TFRecordWriter("./cifar/cifar.tfrecords")
            # 2、循环所有样本写入文件,每张图片样本都要构造example协议
            for i in range(10):
                # 取出第i个图片数据的特征值和目标值
                image = image_batch[i].eval().tostring()
                label = int(label_batch[i].eval()[0])
                # 构造一个样本的example
                example = tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
                # 写入单独的样本
                writer.write(example.SerializeToString())
            # 关闭
            writer.close()
    
            return None
    
        def read_from_tfrecords(self):
    
            # 1、构造文件队列
            file_queue = tf.train.string_input_producer(["./cifar/cifar.tfrecords"])
            # 2、构造文件阅读器,读取内容example,value=一个样本的序列化example
            reader = tf.TFRecordReader()
            key, value = reader.read(file_queue)
            # 3、解析example
            features = tf.parse_single_example(value, features={
                "image": tf.FixedLenFeature([], tf.string),
                "label": tf.FixedLenFeature([], tf.int64),
            })
            print(features['image'], features['label'])
            # 4、解码内容,如果读取的内容格式是string,需要解码,如果是int64,float32不需要解码
            image = tf.decode_raw(features["image"], tf.uint8)
            # 固定图片形状,方便批处理
            image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
            label = tf.cast(features["label"], tf.int32)
            print(image_reshape, label)
            # 进行批处理
            image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    
            return image_batch, label_batch
    
    def picread(filelist):
        """
        读取大黄蜂图片并转换成张量
        :param filelist: 文件路径+文件名的列表
        :return: 每张图片的张量
        """
        # 1、构造文件队列
        file_queue = tf.train.string_input_producer(filelist)
        # 2、构造阅读器读取图片内容(默认读取一张图片)
        reader = tf.WholeFileReader()
        key, value = reader.read(file_queue)
        print(key, value)
        # 3、对图片数据解码
        image = tf.image.decode_jpeg(value)
        print(image)
        # 4、处理图片,统一大小
        image_resize = tf.image.resize_images(image, [200, 200])
        # 注意:一定要把样本的形状固定,[200,200,3] ,在批处理的时候要求所有数据形状必须定义
        image_resize.set_shape([200, 200, 3])
        print(image_resize)
        # 5、进行批处理
        image_batch = tf.train.batch([image_resize], batch_size=20, num_threads=1, capacity=20)
        print(image_batch)
    
        return image_batch
    
    if __name__=="__main__":
    
        if(PictureReader == 1):
            # 1、找到文件,放入列表
            filename = os.listdir("./images/")
            print(filename)
            filelist = [os.path.join("./images/", file) for file in filename]
            print(filelist)
            image_batch = picread(filelist)
            with tf.Session() as sess:
                # 定义线程协调器
                coord = tf.train.Coordinator()
                # 开启读取文件的线程
                threads = tf.train.start_queue_runners(sess, coord=coord)
                # 打印读取内容
                print(sess.run(image_batch))
                # 回收子线程
                coord.request_stop()
                coord.join(threads)
    
        if(BinaryReader == 1):
            # 1、找到文件,放入列表
            file_name = os.listdir("./cifar/")
            print(file_name)
            filelist = [os.path.join("./cifar/", file) for file in file_name if file[-3:] == "bin"]
            print(filelist)
            cf = CifarRead(filelist)
            image_batch, label_batch = cf.read_and_decode()
            with tf.Session() as sess:
                # 定义线程协调器
                coord = tf.train.Coordinator()
                # 开启读取文件的线程
                threads = tf.train.start_queue_runners(sess, coord=coord)
                # 存进tfrecords文件
                print("Start...")
                cf.write_to_tfrecords(image_batch, label_batch)
                print("Finish...")
                # 打印读取内容
                print(sess.run([image_batch, label_batch]))
                # 回收子线程
                coord.request_stop()
                coord.join(threads)
    
        if(TfrecordsReader == 1):
            # 1、找到文件,放入列表
            file_name = os.listdir("./cifar/")
            print(file_name)
            filelist = [os.path.join("./cifar/", file) for file in file_name if file[-3:] == "bin"]
            print(filelist)
            cf = CifarRead(filelist)
            image_batch, label_batch = cf.read_from_tfrecords()
            with tf.Session() as sess:
                # 定义线程协调器
                coord = tf.train.Coordinator()
                # 开启读取文件的线程
                threads = tf.train.start_queue_runners(sess, coord=coord)
                # 打印读取内容
                print(sess.run([image_batch, label_batch]))
                # 回收子线程
                coord.request_stop()
                coord.join(threads)
    

    PS:最近有些迷茫,不知道这样坚持的意义在哪,只知道学的越多,发现自己不知道的越多。不过转念一想,我只是把别人躺床上刷抖音的时间用在所学上,相比那些当夜转瞬即逝的多巴胺,记录才是永恒的肌肉电击。谨此致敬,那些已经回不去的曾经,这些尚处迷茫的现在,和即将捉摸不定的未来!

    相关文章

      网友评论

          本文标题:TensorFlow自学第4篇——文件读写

          本文链接:https://www.haomeiwen.com/subject/puvodqtx.html