美文网首页
4_datasets_quickstart

4_datasets_quickstart

作者: happy_19 | 来源:发表于2018-06-12 20:06 被阅读13次

tf.data模块包含一系列类,用于加载数据、操作数据并通过管道将数据传送给模型。本文主要介绍之前提到的iris_data.py中的train_input_fn函数。

0 train_input_fn定义

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Build the Iterator, and return the read end of the pipeline.
    return dataset.make_one_shot_iterator().get_next()

接下来对这个函数进行简单介绍。

1. Arguments

该函数需要如下三个参数:

  • features:包含有原始输入特征的{"feature_name": array}字典或者DataFrame
  • labels:包含每个样本标签的数组
  • batch_size:表示所需批次大小的整数

2. Slices

最简单的情况,可以使用tf.data.Dataset.from_tensor_slices接收一个数组,并创建该数组的slices表示的tf.data.Dataset,这个方法根据数组的第一维创建对应的slices。比如mnist训练数据集的形状是(60000, 28, 28),通过from_tensor_slices返回的Dataset对象包含有60000个slices,其中每一个都是28*28的图像,具体代码如下所示:

train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print mnist_ds

上述代码打印出如下内容,展示了数据集中slices的shapes以及types。需要注意的是,我们并不知道Dataset中的包含有多少个slices。

<TensorSliceDataset shapes: (28, 28), types: tf.uint8>

上述数据集表示了一个简单的数组,但是实际上Dataset可以表示更复杂的情况。如下所示,如果feature是一个标准的python字典,那么创建的Datasetshapestypes也将会被保留:

dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print dataset
<TensorSliceDataset

  shapes: {
    SepalLength: (), PetalWidth: (),
    PetalLength: (), SepalWidth: ()},

  types: {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64}
>

同样的在之前提到的train_input_fn中,我们传递的是一个(dict(features), labels)这样的数据机,那么创建的Dataset同样会保留其结构信息,如下所示:

dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (),
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

3 manipulation

当前创建的Dataset会按固定顺序迭代,并且一次仅生成一个元素。在它被用于训练之前,还需要其他的操作。tf.data.Dataset类提供了一系列方法来处理数据并生成后续训练可用的数据。如下所示:

dataset = dataset.shuffle(1000).repeat().batch(batch_size)

shuffle方法使用一个固定的缓冲区,将Dataset中的slices进行随即化处理。这里将buffer_size设置的比Dataset中的slices数要大一些,可以保证数据可以完全被随机化处理(iris数据一共有150条样本)
repeat方法会在调用结束后重启Dataset,保证后续训练时这个数据集可以使用。
batch方法会收集样本,并将它们放在一起以创建批次(有时候使用样本进行训练是按照batch进行训练的,例如mini batch mini batch gradient descent优化算法),这为Dataset的shapes增加了一个维度。如下代码对之前的mnist Dataset使用batch方法,生成100个批次的数据,每一个批次都是包含有多个slices,其中每个slices都是一个28*28的图像数据。

print mnist_ds.batch(100)
<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>

需要注意的是,Dataset中第一维shapes是不确定的,因为最后一个批次所具有的slices数量是不确定的。
train_input_fn中,经过批处理之后,Dataset的结构如下所示:

print dataset
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

4 return

train_input_fn中返回的Dataset包含的是(feature_dict, labels)对。在后续trainevaluate使用的都是这种结构,但是在predictlabels被省略了。

相关文章

  • 4_datasets_quickstart

    tf.data模块包含一系列类,用于加载数据、操作数据并通过管道将数据传送给模型。本文主要介绍之前提到的iris_...

网友评论

      本文标题:4_datasets_quickstart

      本文链接:https://www.haomeiwen.com/subject/grrreftx.html