[0.2] TensorFlow Pitfalls: The Headache of tf.data

Author: 澜夕 | Published 2018-08-23 14:38

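    Today I'll try to summarize how to use the tf.data API. I ended up using it because the data I need to process is very large and is stored in a distributed fashion across several servers, so the traditional way of feeding data was not an option; I used tf.data to preprocess the data instead. Since I am also due for a write-up, here is my attempt at covering some of its usage. If you spot any mistakes, please let me know.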

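    Reading data in TensorFlow

    Let's first look at how TensorFlow reads data.

    This article explains TensorFlow's data-reading mechanism very well; it is worth reading first to get an overview.

    How is the Dataset API used?

    The materials above all explain tf.data well, but I could not find a complete example that uses tf.data.TextLineDataset() and tf.data.TFRecordDataset(), which is why I decided to write this summary.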

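    The classic MNIST example

    This post builds on the classic MNIST example. For two kinds of source data, CSV and TFRecord, it uses tf.data.TextLineDataset() and tf.data.TFRecordDataset() respectively to create a Dataset, and then consumes it with four kinds of Iterator: one-shot, initializable, reinitializable, and feedable.

    I have put the related material on my GitHub; follows are welcome. It includes the source data in .csv and .tfrecords form (the data folder) and the full code as a Jupyter notebook (tf_dataset_learn.ipynb).

    If you want to try it, you can run
    git clone https://github.com/lanhongvp/tensorflow_dataset_learn.git
    and then execute the notebook in Jupyter and look at the output, which gives a more intuitive picture. For Git and GitHub usage, see my VSCODE_GIT post. Next, a brief walkthrough of the MNIST example.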

    tf.data.TFRecordDataset() & make_one_shot_iterator()

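    tf.data.TFRecordDataset() takes the paths of .tfrecords files directly as input. Records are streamed from those files rather than loaded all at once, so the dataset does not have to fit in memory, which is what makes it usable when the data is too large for the usual feeding approach. In this post the path is /Users/honglan/Desktop/train_output.tfrecords, which is the path on my own machine; change it to wherever your file lives.
    make_one_shot_iterator() creates a one-shot iterator, the simplest kind: it supports only a single pass over the dataset and needs no explicit initialization.
    Combined with MNIST and tf.data.TFRecordDataset(), the code looks like this.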

    # Validate tf.data.TFRecordDataset() using make_one_shot_iterator()
    import tensorflow as tf
    import numpy as np
    
    num_epochs = 2
    num_class = 10
    sess = tf.Session()
    
    # Use `tf.parse_single_example()` to extract data from a `tf.Example`
    # protocol buffer, and perform any additional per-record preprocessing.
    def parser(record):
        keys_to_features = {
            "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
            "pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
            "label": tf.FixedLenFeature((), tf.int64,
                                        default_value=tf.zeros([], dtype=tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
    
        # Parse the string into an array of pixels corresponding to the image
        images = tf.decode_raw(parsed["image_raw"],tf.uint8)
        images = tf.reshape(images,[28,28,1])
        labels = tf.cast(parsed['label'], tf.int32)
        labels = tf.one_hot(labels,num_class)
        pixels = tf.cast(parsed['pixels'], tf.int32)
        print("IMAGES",images)
        print("LABELS",labels)
        
        return {"image_raw": images}, labels
    
    
    filenames = ["/Users/honglan/Desktop/train_output.tfrecords"] 
    # replace the filenames with your own path
    dataset = tf.data.TFRecordDataset(filenames)
    print("DATASET",dataset)
    
    # Use `Dataset.map()` to build a pair of a feature dictionary and a label
    # tensor for each example.
    dataset = dataset.map(parser)
    print("DATASET_1",dataset)
    dataset = dataset.shuffle(buffer_size=10000)
    print("DATASET_2",dataset)
    dataset = dataset.batch(32)
    print("DATASET_3",dataset)
    dataset = dataset.repeat(num_epochs)
    print("DATASET_4",dataset)
    iterator = dataset.make_one_shot_iterator()
    
    # `features` is a dictionary in which each value is a batch of values for
    # that feature; `labels` is a batch of labels.
    features, labels = iterator.get_next()
    
    print("FEATURES",features)
    print("LABELS",labels)
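    # Each `sess.run` pulls one batch; once all epochs are consumed, a further
    # call raises `tf.errors.OutOfRangeError`.
    print("SESS_RUN_LABELS \n",sess.run(labels))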
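
    The pipeline above calls batch(32) before repeat(num_epochs). The order of these two steps changes where batch boundaries fall. The following is a minimal pure-Python sketch of that effect, not TensorFlow code: the helpers `batch` and `repeat` are hypothetical stand-ins written for this illustration, and a 5-element list plays the role of the parsed records.

```python
from itertools import chain, islice


def batch(items, batch_size):
    """Group consecutive items into lists of at most `batch_size`."""
    it = iter(items)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield chunk


def repeat(make_items, count):
    """Chain `count` fresh passes over the data."""
    return chain.from_iterable(make_items() for _ in range(count))


examples = list(range(5))  # stand-in for 5 parsed records

# batch -> repeat: every epoch ends with its own short batch.
batch_then_repeat = list(repeat(lambda: batch(examples, 2), 2))

# repeat -> batch: batches run straight across the epoch boundary.
repeat_then_batch = list(batch(repeat(lambda: iter(examples), 2), 2))

print(batch_then_repeat)  # [[0, 1], [2, 3], [4], [0, 1], [2, 3], [4]]
print(repeat_then_batch)  # [[0, 1], [2, 3], [4, 0], [1, 2], [3, 4]]
```

    With batch before repeat, each epoch stays separate and may end in a smaller final batch; with repeat before batch, only the very last batch can be short, but a batch may mix examples from two epochs.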
    

    tf.data.TFRecordDataset() & Initializable iterator

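    make_initializable_iterator() creates an initializable iterator. Before using it you must explicitly run its iterator.initializer operation. Because the file list can be fed through a placeholder at initialization time, an initializable iterator can be used to switch between the training and validation sets.
    The code with MNIST is as follows.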

    # Validate tf.data.TFRecordDataset() using make_initializable_iterator()
    # In order to switch between train and validation data
    # (reuses `tf` and `sess` from the previous snippet)
    num_epochs = 2
    num_class = 10
    
    def parser(record):
        keys_to_features = {
            "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
            "pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
            "label": tf.FixedLenFeature((), tf.int64,
                                        default_value=tf.zeros([], dtype=tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
        
        # Parse the string into an array of pixels corresponding to the image
        images = tf.decode_raw(parsed["image_raw"],tf.uint8)
        images = tf.reshape(images,[28,28,1])
        labels = tf.cast(parsed['label'], tf.int32)
        labels = tf.one_hot(labels,num_class)
        pixels = tf.cast(parsed['pixels'], tf.int32)
        print("IMAGES",images)
        print("LABELS",labels)
        
        return {"image_raw": images}, labels
    
    
    filenames = tf.placeholder(tf.string, shape=[None])
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parser) # Parse the record into tensors
    # print("DATASET",dataset)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(32)
    dataset = dataset.repeat(num_epochs)
    print("DATASET",dataset)
    iterator = dataset.make_initializable_iterator()
    features, labels = iterator.get_next()
    print("ITERATOR",iterator)
    print("FEATURES",features)
    print("LABELS",labels)
    
    
    # Initialize `iterator` with training data.
    training_filenames = ["/Users/honglan/Desktop/train_output.tfrecords"] 
    # replace the filenames with your own path
    sess.run(iterator.initializer,feed_dict={filenames: training_filenames})
    print("TRAIN\n",sess.run(labels))
    # print(sess.run(features))
    
    # Initialize `iterator` with validation data.
    validation_filenames = ["/Users/honglan/Desktop/val_output.tfrecords"] 
    # replace the filenames with your own path
    sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
    print("VAL\n",sess.run(labels))
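
    Both TFRecord pipelines above call shuffle(buffer_size=10000). Dataset.shuffle does not shuffle the whole file at once: it fills a buffer of buffer_size elements, draws random elements from that buffer, and refills each drawn slot with the next input element. Below is a rough pure-Python sketch of that idea; the function `buffered_shuffle` and its `seed` parameter are hypothetical names for this illustration, not part of TensorFlow.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield `items` in the order a fixed-size shuffle buffer would emit them."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]        # emit a random buffered element
        buffer[idx] = item       # the new element takes over its slot
    # Input exhausted: drain whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


small = list(buffered_shuffle(range(20), buffer_size=4, seed=0))
unshuffled = list(buffered_shuffle(range(5), buffer_size=1))
print(small)       # a permutation of 0..19 that is only locally mixed
print(unshuffled)  # [0, 1, 2, 3, 4]
```

    An element can only move forward by fewer than buffer_size positions, so a buffer much smaller than the dataset mixes records only locally, and a buffer of size 1 does not shuffle at all. A fully uniform shuffle needs a buffer at least as large as the dataset.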
    
    

    tf.data.TextLineDataset() & Reinitializable iterator

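    tf.data.TextLineDataset() takes as input the path of a text source file, for example one with a .csv or .txt extension.
    The iterator used here is the reinitializable iterator. The official description is quoted below, followed by the MNIST code.

    A reinitializable iterator can be initialized from multiple different Dataset objects. For example, you might have a training input pipeline that uses random perturbations to the input images to improve generalization, and a validation input pipeline that evaluates predictions on unmodified data. These pipelines will typically use different Dataset objects that have the same structure (i.e. the same types and compatible shapes for each component).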

    # validate tf.data.TextLineDataset() using Reinitializable iterator
    # In order to switch between train and validation data
    # (reuses `tf`, `sess`, `num_class` and `num_epochs` from the snippets above)
    
    def decode_line(line):
        # Decode the line to tensor
        record_defaults = [[1.0] for col in range(785)]
        items = tf.decode_csv(line, record_defaults)
        features = items[1:785]
        label = items[0]
    
        features = tf.cast(features, tf.float32)
        features = tf.reshape(features,[28,28,1])
        label = tf.cast(label, tf.int64)
        label = tf.one_hot(label,num_class)
        return features,label
    
    
    def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):
        """create dataset for train and validation dataset"""
        # skip the first line of each file (assumed to be a header row)
        dataset = tf.data.TextLineDataset(filename).skip(1)
        if n_repeats > 0:
            dataset = dataset.repeat(n_repeats)         # for train
        # decode each line; a `normalize` step (not defined here) could be chained:
        # dataset = dataset.map(decode_line).map(normalize)
        dataset = dataset.map(decode_line)
        if is_shuffle:
            dataset = dataset.shuffle(10000)            # shuffle
        dataset = dataset.batch(batch_size)
        return dataset
    
    
    training_filenames = ["/Users/honglan/Desktop/train.csv"] 
    # replace the filenames with your own path
    validation_filenames = ["/Users/honglan/Desktop/val.csv"] 
    # replace the filenames with your own path
    
    # Create different datasets
    training_dataset = create_dataset(training_filenames, batch_size=32,
                                      is_shuffle=True, n_repeats=num_epochs)
    validation_dataset = create_dataset(validation_filenames, batch_size=32,
                                        is_shuffle=True, n_repeats=num_epochs)
    
    # A reinitializable iterator is defined by its structure. We could use the
    # `output_types` and `output_shapes` properties of either `training_dataset`
    # or `validation_dataset` here, because they are compatible.
    iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                               training_dataset.output_shapes)
    features, labels = iterator.get_next()
    
    training_init_op = iterator.make_initializer(training_dataset)
    validation_init_op = iterator.make_initializer(validation_dataset)
    
    # Using reinitializable iterator to alternate between training and validation.
    sess.run(training_init_op)
    print("TRAIN\n",sess.run(labels))
    # print(sess.run(features))
    
    # Reinitialize `iterator` with validation data.
    sess.run(validation_init_op)
    print("VAL\n",sess.run(labels))
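
    To make the row layout that decode_line above expects more concrete, here is a pure-Python sketch of the same transformation for a single CSV row: column 0 is the label, columns 1 to 784 are the pixel values. The function `decode_line_py` and the synthetic row are hypothetical, written only for this illustration; unlike tf.decode_csv, the sketch does not substitute default values for empty fields.

```python
NUM_CLASS = 10
NUM_PIXELS = 28 * 28


def decode_line_py(line, num_class=NUM_CLASS):
    """Split one CSV row into a 28x28 pixel grid and a one-hot label."""
    fields = [float(x) for x in line.strip().split(",")]
    if len(fields) != NUM_PIXELS + 1:
        raise ValueError("expected 785 columns, got %d" % len(fields))
    label = int(fields[0])
    pixels = fields[1:]
    # Row-major reshape of the 784 values into 28 rows of 28 pixels.
    image = [pixels[r * 28:(r + 1) * 28] for r in range(28)]
    one_hot = [1.0 if i == label else 0.0 for i in range(num_class)]
    return image, one_hot


# Synthetic row: label 3, all pixels 0 except the last one.
row = "3," + ",".join(["0"] * 783 + ["255"])
image, one_hot = decode_line_py(row)
print(len(image), len(image[0]))  # 28 28
print(one_hot)       # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(image[27][27])  # 255.0
```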
    
    

    tf.data.TextLineDataset() & Feedable iterator.

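    The data is read the same way as in the previous part, with tf.data.TextLineDataset(). The iterator used here is the feedable iterator, which can be used together with tf.placeholder to select, through the familiar feed_dict mechanism, which Iterator to use in each call to tf.Session.run. tf.data.Iterator.from_string_handle defines a feedable iterator that lets you switch between the two datasets. The code with MNIST is as follows.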

    # validate tf.data.TextLineDataset() using a feedable iterator
    # In order to switch between train and validation data
    
    def decode_line(line):
        # Decode the line to tensor
        record_defaults = [[1.0] for col in range(785)]
        items = tf.decode_csv(line, record_defaults)
        features = items[1:785]
        label = items[0]
    
        features = tf.cast(features, tf.float32)
        features = tf.reshape(features,[28,28])
        label = tf.cast(label, tf.int64)
        label = tf.one_hot(label,num_class)
        return features,label
    
    
    def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):
        """create dataset for train and validation dataset"""
        # skip the first line of each file (assumed to be a header row)
        dataset = tf.data.TextLineDataset(filename).skip(1)
        if n_repeats > 0:
            dataset = dataset.repeat(n_repeats)         # for train
        # decode each line; a `normalize` step (not defined here) could be chained:
        # dataset = dataset.map(decode_line).map(normalize)
        dataset = dataset.map(decode_line)
        if is_shuffle:
            dataset = dataset.shuffle(10000)            # shuffle
        dataset = dataset.batch(batch_size)
        return dataset
    
    
    training_filenames = ["/Users/honglan/Desktop/train.csv"] 
    # replace the filenames with your own path
    validation_filenames = ["/Users/honglan/Desktop/val.csv"] 
    # replace the filenames with your own path
    
    # Create different datasets
    training_dataset = create_dataset(training_filenames, batch_size=32,
                                      is_shuffle=True, n_repeats=num_epochs)
    validation_dataset = create_dataset(validation_filenames, batch_size=32,
                                        is_shuffle=True, n_repeats=num_epochs)
    
    # A feedable iterator is defined by a handle placeholder and its structure. We
    # could use the `output_types` and `output_shapes` properties of either
    # `training_dataset` or `validation_dataset` here, because they have
    # identical structure.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_dataset.output_types, training_dataset.output_shapes)
    features, labels = iterator.get_next()
    
    # You can use feedable iterators with a variety of different kinds of iterator
    # (such as one-shot and initializable iterators).
    training_iterator = training_dataset.make_one_shot_iterator()
    validation_iterator = validation_dataset.make_initializable_iterator()
    
    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
    # and used to feed the `handle` placeholder.
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    
    # Using different handle to alternate between training and validation.
    print("TRAIN\n",sess.run(labels, feed_dict={handle: training_handle}))
    # print(sess.run(features, feed_dict={handle: training_handle}))
    
    # The validation iterator is initializable, so run its initializer before use.
    sess.run(validation_iterator.initializer)
    print("VAL\n",sess.run(labels, feed_dict={handle: validation_handle}))
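
    One practical difference from the reinitializable iterator is that switching handles does not restart either dataset: each underlying iterator keeps its own position. The following pure-Python analogy sketches that behaviour. The class `HandleSwitch` and its methods are hypothetical and only mimic the idea of feeding a handle; they are not TensorFlow APIs.

```python
class HandleSwitch:
    """Route each `get_next` call to one of several independent iterators."""

    def __init__(self):
        self._iterators = {}

    def register(self, name, iterable):
        # The returned name plays the role of the string handle.
        self._iterators[name] = iter(iterable)
        return name

    def get_next(self, handle):
        # Only the selected iterator advances; the others keep their position.
        return next(self._iterators[handle])


switch = HandleSwitch()
train_handle = switch.register("train", ["t0", "t1", "t2"])
val_handle = switch.register("val", ["v0", "v1"])

seen = [
    switch.get_next(train_handle),
    switch.get_next(val_handle),
    switch.get_next(train_handle),  # resumes at t1, not t0
]
print(seen)  # ['t0', 'v0', 't1']
```

    In the TensorFlow code above, the value fed for `handle` plays the role of the name passed to `get_next` here.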
    
    

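    Summary

    • Processing the data is noticeably faster with tfrecords.
    • You can pick whichever iterator suits your needs to preprocess the source data.
    • The same tf.data APIs also work for single-machine training.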
