美文网首页
Tensorflow针对不定尺寸的图片读写tfrecord文件总

Tensorflow针对不定尺寸的图片读写tfrecord文件总

作者: CapsulE_07 | 来源:发表于2018-11-22 17:08 被阅读0次

    介绍

    最近在读取tfrecord时,遇到了关于tensorf shape的问题。

    我们需要知道,大多数情况下图片进行encode编码保存在tfrecord时 是一个一维张量,shape为(1,)。 而在输入神经网络之前,我们必须要将这个图片张量reshape成一个合乎网络结构需求的三维张量。
    在针对这样的需求时,我们会发现,大部分同学会选择在生成tfrecord前就定义好网络的输入shape,例如[224,224,3], 然后将所有的图片先reshape成这个大小,接着存储在tfrecord中。
    这种方式的优点在于提前完成的reshape,避免了后续很多的shape uncompatible 的问题,以及后续训练中不用再对图片进行reshape,加快了训练速度。
    缺点在于,限制了网络输入尺寸的定义。每修改一次神经网络的输入shape。

    当我们需要从存储着不定尺寸图片的tfrecord读取数据时, 我们是无法直接将图片reshape成指定的网络结构输入尺寸的。例如图片大小 [667,1085,3]。显然,我们无法直接将其reshape成 [224,224,3]的。那么我们该如何处理呢?

    按照思路,我们应该先将图片的一维tensor 转换成三维tensor, 然后再利用 tf.image库中不同的reshape 操作,将三维图片tensor转换为需要的 tensor大小。

    按照这种思路,在这里,我总结了两种读写tfrecord的方式,并对这两种方式的不同点,尤其是容易导致bug的地方进行了整理。

    第一种: 利用slim.dataset.Dataset读写tfrecord文件,这种方式常见于利用slim库进行目标检测等网络的实现过程中。
    第二种:tf.parse_single_example 是更为常见的一种方式

    利用slim.dataset.Dataset读写tfrecord文件

    利用这个这个接口读写tfrecord非常的方便。它的神奇之处在于,
    它不需要图片宽高的信息,只需要其二进制string tensor。 这个接口会自动返回一个三维图片tensor。 在此基础上,我们可以很方便的对其进行reshape,然后输入神经网络。
    具体步骤如下:
    在生成tfrecord文件时,我们需要先定义 tf_example的写入格式,然后在将图片文件依据这个写入格式,生成tfrecord文件

    • 定义 tf_example的写入特征
    def int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    def int64_list_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
    
    def bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def bytes_list_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
    
    def float_list_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    
    def create_tf_example(image_path, label, resize_size=None):
        with tf.gfile.GFile(image_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = Image.open(encoded_jpg_io)
    
        # 对于可能存在RGBA 4通道的图片进行处理
        image,is_process = process_image_channels(image)
    
        # 如有必要,那么就在生成tfrecord时即进行resize
        width, height = image.size
        if resize_size is not None:
            if width > height:
                width = int(width * resize_size / height)
                height = resize_size
            else:
                width = resize_size
                height = int(height * resize_size / width)
            image = image.resize((width, height), Image.ANTIALIAS)
        # update encode_jpg
        if is_process or resize_size is not None:
            bytes_io = io.BytesIO()
            image.save(bytes_io, format='JPEG')
            encoded_jpg = bytes_io.getvalue()
    
        tf_example = tf.train.Example(
            features=tf.train.Features(feature={
                'image/encoded': bytes_feature(encoded_jpg),
                'image/format': bytes_feature('jpg'.encode()),
                'image/class/label': int64_feature(label),
                'image/height': int64_feature(height),
                'image/width': int64_feature(width)}))
        return tf_example
    
    • 生成完整的tfrecord文件
      在定义完对应的tf_example 方式后,我们可以遍历图片文件,生成完整的tfrecord文件了。
    def generate_tfrecord(annotation_dict, output_path, resize_size=None):
        num_valid_tf_example = 0
        writer = tf.python_io.TFRecordWriter(output_path)
        for image_path, label in annotation_dict.items():
            if not tf.gfile.GFile(image_path):
                print('%s does not exist.' % image_path)
                continue
            tf_example = create_tf_example(image_path, label, resize_size)
            if tf_example:
                writer.write(tf_example.SerializeToString())
                num_valid_tf_example += 1
    
                if num_valid_tf_example % 100 == 0:
                    print('Create %d TF_Example.' % num_valid_tf_example)
        writer.close()
        print('Total create TF_Example: %d' % num_valid_tf_example)
    

    对应着,在读取tfrecord时,slim提供了 slim.dataset.Dataset 的API接口,非常方便对读入的tfrecord数据进行操作。

    def get_record_dataset(record_path,
                           reader=None, 
                           num_samples=50000, 
                           num_classes=32):
        """Get a tensorflow record file.
        
        Args:
            
        """
        if not reader:
            reader = tf.TFRecordReader
            
        keys_to_features = {
            'image/encoded': 
                tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': 
                tf.FixedLenFeature((), tf.string, default_value='jpeg'),
            'image/class/label': 
                tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                                   dtype=tf.int64))}
            
        items_to_handlers = {
            'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                                  format_key='image/format'),
            'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
        decoder = slim.tfexample_decoder.TFExampleDecoder(
            keys_to_features, items_to_handlers)
        
        labels_to_names = None
        items_to_descriptions = {
            'image': 'An image with shape image_shape.',
            'label': 'A single integer.'}
        return slim.dataset.Dataset(
            data_sources=record_path,
            reader=reader,
            decoder=decoder,
            num_samples=num_samples,
            num_classes=num_classes,
            items_to_descriptions=items_to_descriptions,
            labels_to_names=labels_to_names)
    
    

    在返回了slim.dataset.Dataset这个slim支持的data封装后, 我们可直接对返回的图片数据进行reshape,保证这个图片张量的shape与网络结构的输入层shape一致。

       dataset = get_record_dataset(FLAGS.record_path, num_samples=num_samples, 
                                     num_classes=FLAGS.num_classes)
        data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
        image, label = data_provider.get(['image', 'label'])
        
        # 输出当前tensor的静态shape 和动态shape,与另一种读取方式进行对比
        print("----------tf.shape(image): ",tf.shape(image))
        print("----------image.get_shape(): ",image.get_shape())
        image = _fixed_sides_resize(image, output_height=368, output_width=368)
            
        inputs, labels = tf.train.batch([image, label],
                                        batch_size=FLAGS.batch_size,
                                        #capacity=5*FLAGS.batch_size,
                                        allow_smaller_final_batch=True)
    

    其中,对三维图片张量进行reshape的代码如下

    def _fixed_sides_resize(image, output_height, output_width):
        """Resize images by fixed sides.
        
        Args:
            image: A 3-D image `Tensor`.
            output_height: The height of the image after preprocessing.
            output_width: The width of the image after preprocessing.
    
        Returns:
            resized_image: A 3-D tensor containing the resized image.
        """
        output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
        output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)
    
        image = tf.expand_dims(image, 0)
        resized_image = tf.image.resize_nearest_neighbor(
            image, [output_height, output_width], align_corners=False)
        resized_image = tf.squeeze(resized_image)
        resized_image.set_shape([None, None, 3])
        return resized_image
    
    

    完成了这几步之后,我们就可以利用image 和 label 进行神经网络训练了。

    利用tf.parse_single_example 读写tfrecord文件

    这种方式我们需要自己手动将一维的图片tensor,先还原成三维图片tensor。 因为每一张图片的shape不相同。那么我们需要将图片的shape也存入tfrecord文件中。当我们从tfrecord文件中读取时,我们先利用tf.reshape将一维图片张量还原成三维图片张量,再reshape规定的网络输入尺寸。

    • 照例,此处的重点在于tf_example的构建。在这一部分,我将图片的shape作为一个feature,也存入了tfrecord里面。 那么,在对张量的还原时,我们可以利用这个三维的shape tensor,
    def create_tf_example(image_path, label, resize_size=None):
        with tf.gfile.GFile(image_path, 'rb') as fid:
            encoded_jpg = fid.read()
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = Image.open(encoded_jpg_io)
        # 对于RGBA 4通道的图片进行处理
        image,is_process = process_image_channels(image)
    
        # Resize
        width, height = image.size
        if resize_size is not None:
            if width > height:
                width = int(width * resize_size / height)
                height = resize_size
            else:
                width = resize_size
                height = int(height * resize_size / width)
            image = image.resize((width, height), Image.ANTIALIAS)
        
        img_array = np.asarray(image)
        shape = img_array.shape
        byte_image = image.tobytes()
        
        tf_example = tf.train.Example(
            features=tf.train.Features(feature={
                'image': bytes_feature(byte_image),
                'label': int64_feature(label),
                'img_shape': int64_list_feature(shape)}))
        return tf_example
    
    • 在完成这个后,我们仍旧可以使用上述提及的generate_tfrecord 函数来生成对应的tfrecord

    • 那么,对应这种方式生成的tfrecord文件,我们该如何读取呢?
      在这里,我给出对应的parse_example函数就足以了。

    def parse(serialized):
        # Define a dict with the data-names and types we expect to
        # find in the TFRecords file.
        # It is a bit awkward that this needs to be specified again,
        # because it could have been written in the header of the
        # TFRecords file instead.
    
        features = {
            'image':
                tf.FixedLenFeature((), tf.string, default_value=''),
            'label':
                tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1],
                                                                         dtype=tf.int64)),
            'img_shape': 
                tf.FixedLenFeature(shape=(3,), dtype=tf.int64)}
    
        # Parse the serialized data so we get a dict with our data.
        parsed_example = tf.parse_single_example(
            serialized=serialized, features=features)
    
        # Get the image as raw bytes.
        image_raw = parsed_example['image']
    
        # Decode the raw bytes so it becomes a tensor with type.
        image = tf.decode_raw(image_raw, tf.uint8)
        # The type is now uint8 but we need it to be float.
        image = tf.cast(image, tf.float32)
        
        shape = parsed_example['img_shape']
        
        image = tf.reshape(image,shape=shape)
        
        if not (shape[0] == shape[1] == default_img_size):
            image = _fixed_sides_resize(image,default_img_size,default_img_size)
        
        image.set_shape([default_img_size,default_img_size,3])
        label = parsed_example['label']
        # The image and label are now correct TensorFlow types.
        return image, label
    

    在这里,读写tfrecord的重要流程就已经展现好了。

    对比

    这两种方式有一个比较重要的区别,那就是制作tfrecord时存储的图片信息不同。
    使用slim api时 我们制作tfrecord 时,相关代码为

        with tf.gfile.GFile(image_path, 'rb') as fid:
            encoded_jpg = fid.read()
    

    当我们使用第二种方式时,制作tfrecord时存储的图片信息的相关代码如下所示。

    image = Image.open(img_dir)
    byte_image = image.tobytes()
    

    第一种方式保存的图片信息,其字节数不等于图片的height, width, channel的乘积。 所以不能用 第二种的方式去读取这种方式存储的tfrecord。 会出现 reshape时 维度不对的错误。 当然,使用slim.dataset.Dataset 则不需要考虑这个问题了。 网络上使用slim.dataset.Dataset 来加载tfrecord的方式,都是使用第一种方式存储的tfrecord数据。

    第二种方式,其存储的图片字节大小等于图片的height, width, channel的乘积。所以它可以直接用tf.reshape直接将原图矩阵还原回来,然后再进行下一步的reshape操作。

    总结

    之所以写这篇文章,是因为网络上针对不定尺寸图片tfrecord读取的解决方案不是很完善。
    例如 https://stackoverflow.com/questions/40258943/using-height-width-information-stored-in-a-tfrecords-file-to-set-shape-of-a-ten
    将height, width,channel 分别存入tfrecord,然后按照提问者描述这样是不成功的。
    再例如https://stackoverflow.com/questions/35028173/how-to-read-images-with-different-size-in-a-tfrecord-file 提供的解决方案

    image_rows = tf.cast(features['rows'], tf.int32)
    image_cols = tf.cast(features['cols'], tf.int32)
    image_data = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image_data, tf.pack([image_rows, image_cols, 3]))
    

    这种方式在tf.reshape阶段会报错,因为我们无法将 两个tensor和一个int数值组合起来。最完善的方式是直接将shape作为一个整体存入tfrecord中,最终读取出来就是一个张量了。

    相关文章

      网友评论

          本文标题:Tensorflow针对不定尺寸的图片读写tfrecord文件总

          本文链接:https://www.haomeiwen.com/subject/mqgkqqtx.html