美文网首页
5、TFRecord(管理数据)

5、TFRecord(管理数据)

作者: MakeStart | 来源:发表于2019-11-09 20:15 被阅读0次

    TFRecord产生的背景:
    一般情况下,数据集经常分为 train、test 两个文件夹,文件夹内部往往存放着成千上万的图片或文本等文件。这些文件零散地存放在磁盘上,不仅占用磁盘空间,读取时还需要频繁访问磁盘,速度非常慢。TFRecord 内部使用了“Protocol Buffer”二进制数据编码方案,数据连续存放在一个二进制文件中(只占用一个内存块),读取时只需一次性加载这一个二进制文件即可,简单、快速,对大型训练数据尤其友好。而且当训练数据量比较大的时候,还可以将数据分成多个 TFRecord 文件,以提高处理效率。
    下面分为3个方面说明其如何使用:代码结构,TFRecord创建,TFRecord读取

    1、代码结构:

    image.png

    2、TFRecord创建

    #coding=utf-8
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    from datetime import datetime
    import os
    import random
    import sys
    import threading
    
    
    import numpy as np
    import tensorflow as tf
    
    # Command-line flags configuring the TFRecord conversion script.

    # Input/output locations. NOTE: train_directory and validation_directory
    # default to the same path.
    tf.app.flags.DEFINE_string('train_directory', './flower_photos/',
                               'Training data directory')
    tf.app.flags.DEFINE_string('validation_directory', './flower_photos/',
                               'Validation data directory')
    tf.app.flags.DEFINE_string('output_directory', './data/',
                               'Output data directory')
    
    # Number of output TFRecord files per data set. main() asserts that each
    # shard count is divisible by num_threads.
    tf.app.flags.DEFINE_integer('train_shards', 2,
                                'Number of shards in training TFRecord files.')
    tf.app.flags.DEFINE_integer('validation_shards', 0,
                                'Number of shards in validation TFRecord files.')
    
    tf.app.flags.DEFINE_integer('num_threads', 2,
                                'Number of threads to preprocess the images.')
    
    # The list of valid labels is held in the labels file.
    # Assumes that the file contains entries as such:
    #   dog
    #   cat
    #   flower
    # where each line corresponds to a label. _find_image_files maps each label
    # to an integer in line order, starting from 1 (index 0 is left unused as a
    # background class).
    tf.app.flags.DEFINE_string('labels_file', './flower_label.txt', 'Labels file')
    
    
    # Parsed flag values, read by the functions below.
    FLAGS = tf.app.flags.FLAGS
    
    
    def _int64_feature(value):
      """Build an int64 `tf.train.Feature` for an Example proto.

      Args:
        value: a single integer or a list of integers. Anything that is not a
          `list` is wrapped in a one-element list first.
      Returns:
        A `tf.train.Feature` whose `int64_list` holds the value(s).
      """
      values = value if isinstance(value, list) else [value]
      int64_list = tf.train.Int64List(value=values)
      return tf.train.Feature(int64_list=int64_list)
    
    
    def _bytes_feature(value):
      """Build a bytes `tf.train.Feature` for an Example proto.

      Args:
        value: a single bytes value; it is stored as a one-element list.
      Returns:
        A `tf.train.Feature` whose `bytes_list` holds `value`.
      """
      bytes_list = tf.train.BytesList(value=[value])
      return tf.train.Feature(bytes_list=bytes_list)
    
    
    def _convert_to_example(filename, image_buffer, label, text, height, width):
      """Assemble the `tf.train.Example` proto describing one image.

      Args:
        filename: string, path to an image file, e.g., '/path/to/example.JPG'.
          Only its basename is stored.
        image_buffer: string, JPEG encoding of RGB image
        label: integer, identifier for the ground truth for the network
        text: string, unique human-readable, e.g. 'dog'
        height: integer, image height in pixels
        width: integer, image width in pixels
      Returns:
        Example proto
      """

      def _as_bytes_feature(raw):
        # Every string-like field goes through the same bytes conversion.
        return _bytes_feature(tf.compat.as_bytes(raw))

      # Colorspace, channel count and format are fixed for every record.
      feature_map = {
          'image/height': _int64_feature(height),
          'image/width': _int64_feature(width),
          'image/colorspace': _as_bytes_feature('RGB'),
          'image/channels': _int64_feature(3),
          'image/class/label': _int64_feature(label),
          'image/class/text': _as_bytes_feature(text),
          'image/format': _as_bytes_feature('JPEG'),
          'image/filename': _as_bytes_feature(os.path.basename(filename)),
          'image/encoded': _as_bytes_feature(image_buffer),
      }
      return tf.train.Example(features=tf.train.Features(feature=feature_map))
    
    
    class ImageCoder(object):
      """Runs image decoding/encoding ops through one shared TensorFlow Session.

      The graph pieces (a PNG->JPEG transcoder and a JPEG decoder) are built once
      in the constructor and then fed with raw image bytes on every call.
      """

      def __init__(self):
        # One Session serves every image coding call made on this object.
        self._sess = tf.Session()

        # PNG bytes in -> 3-channel image -> JPEG bytes out.
        self._png_data = tf.placeholder(dtype=tf.string)
        decoded_png = tf.image.decode_png(self._png_data, channels=3)
        self._png_to_jpeg = tf.image.encode_jpeg(
            decoded_png, format='rgb', quality=100)

        # JPEG bytes in -> 3-channel decoded image out.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(
            self._decode_jpeg_data, channels=3)

      def png_to_jpeg(self, image_data):
        """Return the JPEG encoding of the PNG bytes in `image_data`."""
        feed = {self._png_data: image_data}
        return self._sess.run(self._png_to_jpeg, feed_dict=feed)

      def decode_jpeg(self, image_data):
        """Decode JPEG bytes and return the resulting rank-3, 3-channel image."""
        feed = {self._decode_jpeg_data: image_data}
        decoded = self._sess.run(self._decode_jpeg, feed_dict=feed)
        # Sanity-check the result: three dimensions, last one of size 3.
        rank = len(decoded.shape)
        assert rank == 3
        assert decoded.shape[2] == 3
        return decoded
    
    
    def _is_png(filename):
      """Determine if a file contains a PNG format image.
    
      Args:
        filename: string, path of the image file.
    
      Returns:
        boolean indicating if the image is a PNG.
      """
      return '.png' in filename
    
    
    def _process_image(filename, coder):
      """Load one image file and return its JPEG bytes plus dimensions.

      Args:
        filename: string, path to an image file e.g., '/path/to/example.JPG'.
        coder: instance of ImageCoder to provide TensorFlow image coding utils.
      Returns:
        image_buffer: string, JPEG encoding of RGB image.
        height: integer, image height in pixels.
        width: integer, image width in pixels.
      """
      # Slurp the raw bytes of the file.
      with tf.gfile.FastGFile(filename, 'rb') as image_file:
        encoded = image_file.read()

      # PNG inputs are transcoded so that every stored record is a JPEG.
      if _is_png(filename):
        print('Converting PNG to JPEG for %s' % filename)
        encoded = coder.png_to_jpeg(encoded)

      # Decoding is done only to validate the image and learn its size.
      decoded = coder.decode_jpeg(encoded)
      assert len(decoded.shape) == 3
      assert decoded.shape[2] == 3
      height, width = decoded.shape[0], decoded.shape[1]

      return encoded, height, width
    
    
    def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
                                   texts, labels, num_shards):
      """Processes and saves list of images as TFRecord in 1 thread.

      Output files are written to FLAGS.output_directory and are named
      '<name>-<shard>-of-<num_shards>.tfrecord' with zero-padded 5-digit numbers.

      Args:
        coder: instance of ImageCoder to provide TensorFlow image coding utils.
        thread_index: integer, unique batch to run index is within [0, len(ranges)).
        ranges: list of pairs of integers specifying ranges of each batches to
          analyze in parallel.
        name: string, unique identifier specifying the data set
        filenames: list of strings; each string is a path to an image file
        texts: list of strings; each string is human readable, e.g. 'dog'
        labels: list of integer; each integer identifies the ground truth
        num_shards: integer number of shards for this data set.
      """
      # Each thread produces N shards where N = int(num_shards / num_threads).
      # For instance, if num_shards = 128, and the num_threads = 2, then the first
      # thread would produce shards [0, 64).
      num_threads = len(ranges)
      assert not num_shards % num_threads
      num_shards_per_batch = int(num_shards / num_threads)

      # Split this thread's [start, end) file range evenly across its shards.
      shard_ranges = np.linspace(ranges[thread_index][0],
                                 ranges[thread_index][1],
                                 num_shards_per_batch + 1).astype(int)
      num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]

      counter = 0
      for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.5d-of-%.5d.tfrecord' % (name, shard, num_shards)
        output_file = os.path.join(FLAGS.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        # try/finally so the writer is closed even if an image fails to process.
        try:
          for i in files_in_shard:
            filename = filenames[i]
            label = labels[i]
            text = texts[i]

            image_buffer, height, width = _process_image(filename, coder)

            example = _convert_to_example(filename, image_buffer, label,
                                          text, height, width)
            writer.write(example.SerializeToString())
            shard_counter += 1
            counter += 1

            if not counter % 1000:
              print('%s [thread %d]: Processed %d of %d images in thread batch.' %
                    (datetime.now(), thread_index, counter, num_files_in_thread))
              sys.stdout.flush()
        finally:
          writer.close()

        print('%s [thread %d]: Wrote %d images to %s' %
              (datetime.now(), thread_index, shard_counter, output_file))
        sys.stdout.flush()
      # Report the number of shards this thread wrote (not the number of files).
      print('%s [thread %d]: Wrote %d images to %d shards.' %
            (datetime.now(), thread_index, counter, num_shards_per_batch))
      sys.stdout.flush()
    
    
    def _process_image_files(name, filenames, texts, labels, num_shards):
      """Process and save list of images as TFRecord of Example protos.

      The file list is split into FLAGS.num_threads contiguous ranges and one
      worker thread running _process_image_files_batch is started per range.

      Args:
        name: string, unique identifier specifying the data set
        filenames: list of strings; each string is a path to an image file
        texts: list of strings; each string is human readable, e.g. 'dog'
        labels: list of integer; each integer identifies the ground truth
        num_shards: integer number of shards for this data set.
      """
      assert len(filenames) == len(texts)
      assert len(filenames) == len(labels)

      # Break all images into batches with a [ranges[i][0], ranges[i][1]].
      # The builtin `int` is used because the `np.int` alias no longer exists in
      # current NumPy releases.
      spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(int)
      ranges = [[start, end] for start, end in zip(spacing[:-1], spacing[1:])]

      # Launch a thread for each batch.
      print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
      sys.stdout.flush()

      # Create a mechanism for monitoring when all threads are finished.
      coord = tf.train.Coordinator()

      # Create a generic TensorFlow-based utility for converting all image codings.
      # The same coder instance is handed to every worker thread.
      coder = ImageCoder()

      threads = []
      for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames,
                texts, labels, num_shards)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

      # Wait for all the threads to terminate.
      coord.join(threads)
      print('%s: Finished writing all %d images in data set.' %
            (datetime.now(), len(filenames)))
      sys.stdout.flush()
    
    
    def _find_image_files(data_dir, labels_file):
      """Build a shuffled list of all image files and labels in the data set.

      Args:
        data_dir: string, path to the root directory of images.

          Assumes that the image data set resides in JPEG files located in
          the following directory structure.

            data_dir/dog/another-image.JPEG
            data_dir/dog/my-image.jpg

          where 'dog' is the label associated with these images.

        labels_file: string, path to the labels file.

          The list of valid labels is held in this file. Assumes that the file
          contains entries as such:
            dog
            cat
            flower
          where each line corresponds to a label. Each label is mapped to an
          integer by line position, starting with 1 for the label on the first
          line; index 0 is left unused as a background class.

      Returns:
        filenames: list of strings; each string is a path to an image file.
        texts: list of strings; each string is the class, e.g. 'dog'
        labels: list of integer; each integer identifies the ground truth.
        The three lists are shuffled with the same fixed-seed permutation.
      """
      print('目标文件夹位置: %s.' % data_dir)
      # Read inside a `with` block so the labels file handle is closed.
      with tf.gfile.FastGFile(labels_file, 'r') as labels_handle:
        unique_labels = [line.strip() for line in labels_handle.readlines()]

      labels = []
      filenames = []
      texts = []

      # Construct the list of JPEG files and labels.
      # Leave label index 0 empty as a background class. The index is tied to the
      # line position, so a label whose directory cannot be listed is skipped
      # without shifting the ids of the labels that follow it.
      for label_index, text in enumerate(unique_labels, start=1):
        jpeg_file_path = '%s/%s/*' % (data_dir, text)
        try:
          matching_files = tf.gfile.Glob(jpeg_file_path)
        except tf.errors.OpError:
          # Filesystem / directory listing error: report the pattern and move on.
          print(jpeg_file_path)
          continue

        labels.extend([label_index] * len(matching_files))
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)

      # Shuffle the ordering of all image files in order to guarantee
      # random ordering of the images with respect to label in the
      # saved TFRecord files. Make the randomization repeatable.
      shuffled_index = list(range(len(filenames)))
      random.seed(12345)
      random.shuffle(shuffled_index)

      filenames = [filenames[i] for i in shuffled_index]
      texts = [texts[i] for i in shuffled_index]
      labels = [labels[i] for i in shuffled_index]

      print('Found %d JPEG files across %d labels inside %s.' %
            (len(filenames), len(unique_labels), data_dir))
      return filenames, texts, labels
    
    
    def _process_dataset(name, directory, num_shards, labels_file):
      """Convert one complete data set into sharded TFRecord files.

      Collects the image files and labels under `directory`, then hands them to
      _process_image_files for writing.

      Args:
        name: string, unique identifier specifying the data set.
        directory: string, root path to the data set.
        num_shards: integer number of shards for this data set.
        labels_file: string, path to the labels file.
      """
      image_paths, class_names, class_ids = _find_image_files(
          directory, labels_file)
      _process_image_files(
          name, image_paths, class_names, class_ids, num_shards)
    
    
    def main(unused_argv):
      """Convert the training images into TFRecord shards.

      Checks that each shard count is divisible by the number of threads, makes
      sure the output directory exists, then processes the training set.

      Args:
        unused_argv: ignored.
      """
      assert not FLAGS.train_shards % FLAGS.num_threads, (
          '在训练集中:线程数量应与建立文件个数相对应')
      assert not FLAGS.validation_shards % FLAGS.num_threads, (
          '在验证集中:线程数量应与建立文件个数相对应')
      print('生成数据文件夹 %s' % FLAGS.output_directory)
      # The shard writers do not create the output directory themselves.
      tf.gfile.MakeDirs(FLAGS.output_directory)

      # Run it!
      _process_dataset('train', FLAGS.train_directory,
                       FLAGS.train_shards, FLAGS.labels_file)
      # NOTE: conversion of the validation set is disabled. To enable it,
      # uncomment the call below and set --validation_shards to a positive
      # multiple of --num_threads (with the default of 0 no shard is written).
      # _process_dataset('validation', FLAGS.validation_directory,
      #                  FLAGS.validation_shards, FLAGS.labels_file)
    
    
    # Script entry point: only when executed directly, hand control to
    # tf.app.run().
    if __name__ == '__main__':
      tf.app.run()
    
    

    3、TFRecord读取:

    
    def read_and_decode(filename_queue):
        """Read one serialized Example from the queue and unpack it.

        Args:
            filename_queue: queue of TFRecord file names, passed to
                `tf.TFRecordReader.read`.

        Returns:
            An `_image_object` instance whose `image` attribute is the decoded
            3-channel JPEG cropped/padded to IMAGE_SIZE x IMAGE_SIZE, and whose
            `height`, `width`, `filename` and `label` attributes come from the
            parsed record (`label` is cast to int64).
        """
        # Fields to extract from every record; all are scalar features.
        feature_spec = {
            "image/encoded": tf.FixedLenFeature([], tf.string),
            "image/height": tf.FixedLenFeature([], tf.int64),
            "image/width": tf.FixedLenFeature([], tf.int64),
            "image/filename": tf.FixedLenFeature([], tf.string),
            "image/class/label": tf.FixedLenFeature([], tf.int64),
        }

        record_reader = tf.TFRecordReader()
        _, serialized = record_reader.read(filename_queue)
        parsed = tf.parse_single_example(serialized, features=feature_spec)

        decoded = tf.image.decode_jpeg(parsed["image/encoded"], channels=3)

        result = _image_object()
        result.image = tf.image.resize_image_with_crop_or_pad(
            decoded, IMAGE_SIZE, IMAGE_SIZE)
        result.height = parsed["image/height"]
        result.width = parsed["image/width"]
        result.filename = parsed["image/filename"]
        result.label = tf.cast(parsed["image/class/label"], tf.int64)
        return result
    
    def flower_input(if_random = True, if_training = True):
        """Build batched (image, label, filename) input tensors from TFRecords.

        Args:
            if_random: when truthy, batches come from `tf.train.shuffle_batch`;
                otherwise from `tf.train.batch`.
            if_training: when truthy, the two "train-..." shard files under
                DATA_DIR are read; otherwise the two "eval-..." shard files.

        Returns:
            Tuple `(image_batch, label_batch, filename_batch)`.

        Raises:
            ValueError: if one of the expected shard files does not exist.
        """
        # NOTE(review): the "eval" prefix does not match the 'validation' data set
        # name used by the conversion script shown earlier -- confirm which
        # script produces the eval-* files.
        prefix = "train" if if_training else "eval"
        shard_paths = [
            os.path.join(DATA_DIR, "%s-0000%d-of-00002.tfrecord" % (prefix, idx))
            for idx in range(0, 2)
        ]

        # Fail on the first shard that is missing.
        missing = next(
            (path for path in shard_paths if not tf.gfile.Exists(path)), None)
        if missing is not None:
            raise ValueError("Failed to find file: " + missing)

        filename_queue = tf.train.string_input_producer(shard_paths)
        record = read_and_decode(filename_queue)
        # Images are standardized per image before batching.
        image = tf.image.per_image_standardization(record.image)
        label = record.label
        filename = record.filename
        tensors = [image, label, filename]

        if if_random:
            # Keep 40% of the training set in the queue to get a good shuffle.
            min_fraction_of_examples_in_queue = 0.4
            min_queue_examples = int(
                TRAINING_SET_SIZE * min_fraction_of_examples_in_queue)
            print("Filling queue with %d images before starting to train. " "This will take a few minutes." % min_queue_examples)
            batched = tf.train.shuffle_batch(
                tensors,
                batch_size = BATCH_SIZE,
                num_threads = 1,
                capacity = min_queue_examples + 3 * BATCH_SIZE,
                min_after_dequeue = min_queue_examples)
        else:
            batched = tf.train.batch(
                tensors,
                batch_size = BATCH_SIZE,
                num_threads = 1)

        image_batch, label_batch, filename_batch = batched
        return image_batch, label_batch, filename_batch
    
    
    

    相关文章

      网友评论

          本文标题:5、TFRecord(管理数据)

          本文链接:https://www.haomeiwen.com/subject/lgwnbctx.html