tf.data
模块包含一系列类,用于加载数据、操作数据并通过管道将数据传送给模型。本文主要介绍之前提到的iris_data.py
中的train_input_fn
函数。
0 train_input_fn定义
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
# Build the Iterator, and return the read end of the pipeline.
return dataset.make_one_shot_iterator().get_next()
接下来对这个函数进行简单介绍。
1. Arguments
该函数需要如下三个参数:
-
features
:包含有原始输入特征的{"feature_name": array}字典或者DataFrame
-
labels
:包含每个样本标签的数组 -
batch_size
:表示所需批次大小的整数
2. Slices
最简单的情况,可以使用tf.data.Dataset.from_tensor_slices
接收一个数组,并创建该数组的slices表示的tf.data.Dataset
,这个方法根据数组的第一维创建对应的slices。比如mnist训练数据集的形状是(60000, 28, 28)
,通过from_tensor_slices
返回的Dataset
对象包含有60000个slices,其中每一个都是28*28的图像,具体代码如下所示:
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print mnist_ds
上述代码打印出如下内容,展示了数据集中slices的shapes以及types。需要注意的是,我们并不知道Dataset中的包含有多少个slices。
<TensorSliceDataset shapes: (28, 28), types: tf.uint8>
上述数据集表示了一个简单的数组,但是实际上Dataset可以表示更复杂的情况。如下所示,如果feature
是一个标准的python字典,那么创建的Dataset
的shapes
和types
也将会被保留:
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print dataset
<TensorSliceDataset
shapes: {
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
types: {
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64}
>
同样的在之前提到的train_input_fn
中,我们传递的是一个(dict(features), labels)
这样的数据机,那么创建的Dataset
同样会保留其结构信息,如下所示:
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
()),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
3 manipulation
当前创建的Dataset
会按固定顺序迭代,并且一次仅生成一个元素。在它被用于训练之前,还需要其他的操作。tf.data.Dataset
类提供了一系列方法来处理数据并生成后续训练可用的数据。如下所示:
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
shuffle
方法使用一个固定的缓冲区,将Dataset
中的slices进行随即化处理。这里将buffer_size
设置的比Dataset
中的slices数要大一些,可以保证数据可以完全被随机化处理(iris数据一共有150条样本)
repeat
方法会在调用结束后重启Dataset
,保证后续训练时这个数据集可以使用。
batch
方法会收集样本,并将它们放在一起以创建批次(有时候使用样本进行训练是按照batch进行训练的,例如mini batch mini batch gradient descent优化算法),这为Dataset
的shapes增加了一个维度。如下代码对之前的mnist Dataset
使用batch
方法,生成100个批次的数据,每一个批次都是包含有多个slices,其中每个slices都是一个28*28的图像数据。
print mnist_ds.batch(100)
<BatchDataset
shapes: (?, 28, 28),
types: tf.uint8>
需要注意的是,Dataset
中第一维shapes是不确定的,因为最后一个批次所具有的slices数量是不确定的。
在train_input_fn
中,经过批处理之后,Dataset
的结构如下所示:
print dataset
<TensorSliceDataset
shapes: (
{
SepalLength: (?,), PetalWidth: (?,),
PetalLength: (?,), SepalWidth: (?,)},
(?,)),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
4 return
在train_input_fn
中返回的Dataset
包含的是(feature_dict, labels)
对。在后续train
、evaluate
使用的都是这种结构,但是在predict
中labels
被省略了。
网友评论