美文网首页深度学习我爱编程
TensorFlow 自定义生成 .record 文件

TensorFlow 自定义生成 .record 文件

作者: 公输睚信 | 来源:发表于2018-04-07 20:06 被阅读268次

    一、生成 .record 文件

            前面的文章 TensorFlow 训练自己的目标检测器 中的第二部分第 2 小节中我们已经预先说过,会在后续的文章中阐述怎么自定义的将图像转化为 .record 文件,今天我们就来说一说这件事。

            在文章 TensorFlow 训练 CNN 分类器 中我们生成了 50000 张 28 x 28 像素的图像,我们的目标就是将这些图像全部写入到一个后缀为 .record 的文件中。 .recordtfrecord 文件是 TensorFlow 中的标准数据读写格式,它是一种能够高效读写的二进制文件,能够快速的复制、移动、读写和存储等。

            在文章 TensorFlow 训练 CNN 分类器 和文章 TensorFlow-slim 训练 CNN 分类模型 中,我们在训练模型时导入数据的方式都是一次性的将所有图像读入,然后循环的从中选择一个批量来训练。这对于小数据集来说不会产生问题,但如果训练数据异常大,那么很可能由于内存限制无法一次性将说有数据导入,这样前面的训练方式便不能采用了。此时,我们可以将数据转化为 .record 文件格式,然后再分批次的、逐步的读入 .record 文件进行训练。

            要将图像写入 .record 文件,首先要将图像编码为字符或数字特征,这需要调用类 tf.train.Feature。然后,在调用 tf.train.Example 将特征写入协议缓冲区。最后,通过类 tf.python_io.TFRecordWriter 将数据写入到 .record 文件中。比如,我们将前面提到的 50000 张图像写入 train.record 文件,使用如下代码(命名为 generate_tfrecord.py):

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Mon Mar 26 09:02:10 2018
    
    @author: shirhe-lyh
    """
    
    """Generate tfrecord file from images.
    
    Example Usage:
    ---------------
    python3 train.py \
        --images_path: Path to the training images (directory).
        --output_path: Path to .record.
    """
    
    import glob
    import io
    import os
    import tensorflow as tf
    
    from PIL import Image
    
    flags = tf.app.flags
    
    flags.DEFINE_string('images_path', None, 'Path to images (directory).')
    flags.DEFINE_string('output_path', None, 'Path to output tfrecord file.')
    FLAGS = flags.FLAGS
    
    
    def int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    
    def int64_list_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
    
    
    def bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    
    def bytes_list_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
    
    
    def float_list_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    
    
    def create_tf_example(image_path):
        with tf.gfile.GFile(image_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = Image.open(encoded_jpg_io)
        width, height = image.size
        label = int(image_path.split('_')[-1].split('.')[0])
        
        tf_example = tf.train.Example(
            features=tf.train.Features(feature={
                'image/encoded': bytes_feature(encoded_jpg),
                'image/format': bytes_feature('jpg'.encode()),
                'image/class/label': int64_feature(label),
                'image/height': int64_feature(height),
                'image/width': int64_feature(width)}))
        return tf_example
    
    
    def generate_tfrecord(images_path, output_path):
        writer = tf.python_io.TFRecordWriter(output_path)
        for image_file in glob.glob(images_path):
            tf_example = create_tf_example(image_file)
            writer.write(tf_example.SerializeToString())
        writer.close()
        
        
    def main(_):
        images_path = os.path.join(FLAGS.images_path, '*.jpg')
        images_record_path = FLAGS.output_path
        generate_tfrecord(images_path, images_record_path)
        
        
    if __name__ == '__main__':
        tf.app.run()
    

    在该文件目录的终端执行:

    python3 generate_tfrecord.py \
        --images_path /home/.../datasets/images \
        --output_path /home/.../datasets/train.record
    

    便会在输出路径下生成 train.record 文件。以上代码中,最重要的部分是:1. 函数 create_tf_example,该函数首先得到图像的二进制格式、图像的宽和高、以及图像对应的类标号等,然后将图像的这些信息写入协议缓冲区;2. 函数 generate_tfrecord,该函数使用 tf.python_io.TFRecordWriter 类将协议缓冲区内的数据写入到 .record 文件中。

    二、读取 .record 文件

            一旦将图像转化为了 .record 文件,接下来我们关心的就是怎么读取这个 .record 文件用于模型训练了。这可以借助我们前面使用过的模块 tf.contrib.slim

    slim = tf.contrib.slim
    
    def get_record_dataset(record_path,
                           reader=None, image_shape=[28, 28, 3], 
                           num_samples=50000, num_classes=10):
        """Get a tensorflow record file."""
        if not reader:
            reader = tf.TFRecordReader
            
        keys_to_features = {
            'image/encoded': 
                tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': 
                tf.FixedLenFeature((), tf.string, default_value='jpeg'),
            'image/class/label': 
                tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                                   dtype=tf.int64))}
            
        items_to_handlers = {
            'image': slim.tfexample_decoder.Image(shape=image_shape, 
                                                  #image_key='image/encoded',
                                                  #format_key='image/format',
                                                  channels=3),
            'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
        
        decoder = slim.tfexample_decoder.TFExampleDecoder(
            keys_to_features, items_to_handlers)
        
        labels_to_names = None
        items_to_descriptions = {
            'image': 'An image with shape image_shape.',
            'label': 'A single integer between 0 and 9.'}
        return slim.dataset.Dataset(
            data_sources=record_path,
            reader=reader,
            decoder=decoder,
            num_samples=num_samples,
            num_classes=num_classes,
            items_to_descriptions=items_to_descriptions,
            labels_to_names=labels_to_names)
    

    主要是借助了 tf.contib.slim 模块中的

    slim.dataset.Dataset(data_sources, reader, decoder,
                      num_samples, items_to_descriptions,
                      **kwargs)
    

    slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
                                            items_to_handlers)
    

    这两个类。

            使用时,直接传入 train.record 路径即可:

    dataset = get_record_dataset('./xxx/train.record')
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    

    函数 get_record_dataset 返回 slim.dataset.Dataset 类的一个对象,之后通过类

    slim.dataset_data_provider.DatasetDataProvider(dataset, num_readers=1,
                                                   reader_kwargs=None,
                                                   shuffle=True, num_epochs=None,
                                                   common_queue_capacity=256,
                                                   common_queue_min=128,
                                                   record_key='record_key',
                                                   seed=None, scope=None)
    

    get 方法得到图像和类标号的序列数据。参数 num_readers=1 表示一次读取一个数据,即一次读取一张图像,因此实际使用是还需要使用函数 tf.train.batch 将数据形成批量再用于训练,见下一篇文章。

    预告:下一篇文章将说明怎么完全使用 tf.contrib.slim 来构建和训练模型。

    相关文章

      网友评论

        本文标题:TensorFlow 自定义生成 .record 文件

        本文链接:https://www.haomeiwen.com/subject/svgyhftx.html