tf.data API 能够让你创建简单的,可复用的,输入流程。例如,在一个分布式文件系统下,一个图片模型可能要聚集数据从文件中,对每张图片使用random perturbations, 然后就可以随机地合并选择的图片到一个批次中进行训练。同时,对于一个文本模型有可能涉及提取symbols从原文本数据中,然后将它们转换为embedding identifiers并使用lookup table, 最后一起批次处理不同长度的序列。tf.data API 使得处理大量的数据,不同数据格式,和复杂的变化,变得非常容易。
tf.data API 介绍了两个新的抽象:
- tf.data.Dataset
一个dataset
表示一个元素的序列,在每个元素里包含一个或者多个Tensor对象。 举个栗子,在一个图片流程管道里,一个元素可能是一个单一的训练样本含有一对tensors表示图片的数据和标签。这里有两个不同的方法去创建一个dataset
- Creating a source (e.g. Dataset.from_tensor_slices()) constructs a dataset from one or more tf.Tensor objects. (这个是用于创建)
- Applying a transformation (e.g. Dataset.batch()) constructs a dataset from one or more tf.data.Dataset objects. (这个是用于应用)
- tf.data.Iterator
一个iterator提供了主要的方法从dataset中抽取元素。 这个操作通过执行 Iterator.get_next() yield dataset里的下个元素,并且典型地充当一个接口在输入流程管道和你的模型之间。一个最简单的iterator是“one-shot iterator”,它是与一个特别的dataset有关系,并迭代整个dataset一次。对于更多的复杂的用法,Iterator.initializer 能够使你重新初始化和参数化一个iterator对于不同的datasets。例如迭代训练和验证数据多次在同一个程序里。
基本机制
想要开始一个输入流程管道,你必须定义一个source。 例如在内存里,将一些tensors构建成一个 Dataset。你可以使用 tf.data.Dataset.from_tensors() 或者 tf.data.Dataset.from_tensor_slices()。或者,如果你的输入数据保存在本地硬盘里,而且是TFRecord格式(强力推荐),又可以构建一个 tf.data.TFRecordDataset。
一旦你有了Dataset对象,你可以将它转换成一个新的dataset通过链接方法调用tf.data.Dataset 对象。例如,你可以使用每个元素转化,像 Dataset.map(),应用一个方法对于每一元素,或者多个元素转换,像 Dataset.batch()
最通常的方法去消耗数值从一个dataset里是使用一个 iterator 对象提供一次访问对于dataset里的一个元素。(例如:调用Dataset.make_one_shot_iterator())。一个 tf.data.Iterator 提供两个操作:Iterator.initializer,它能够使你初始化或者重新初始化迭代器的状态;Iterator.get_next() , 它返还tf.Tensor对象,其符合于符号的下一个元素。根据你情况,有可能选择一个不同类型的迭代器,下面将会列出所有选项。
数据集结构
一个数据集包含的元素有相同的结构。一个元素有一个或者多个 tf.Tensor 对象,被称为组成部分。每一个组成部分有一个tf.DType 代表元素的类型在tensor中,还有一个 tf.TensorShape 代表每一个元素的静态形状。Dataset.output_types 和 Dataset.output_shapes的属性允许你检查一个数据集元素的组成部分的推断数据类型和形状。
创建一个迭代器
一旦建立一个dataset来表示你的输入数据,下一步就是创建一个 Iterator 来访问元素。tf.data API当前支持下列迭代器。
- one-shot
- initializable
- reinitializable
-
feedable
一个 one-shot 迭代器最简单的形式,通过一个数据集,它只支持迭代一次,不需要显性的初始化。One-shot 迭代器解决几乎所有出现基于排队输入的流程管道情况,但是他们不支持参数化。使用 Dataset.range() 的一个例子:
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
一个 initializable 迭代器在使用前需要你去运行一个显性的 iterator.initializer。作为这一不方便的交换,它使你能够参数化数据集的定义,使用一个或者多个 tf.placeholder() 张量,他们可以被给进当你初始化迭代器的时候。
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value
一个 **reinitializable ** 迭代器可以被从多个不同的dataset对象进行初始化。例如,你可能有一个训练输入流程管道,它需要随机打乱来挑选输入图片来提高模型泛化能力,同时一个验证输入流程管道进行评估未修改数据的预测结果。这些流程管道将典型地使用不同的 dataset 对象,但有相同的数据结构。
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
sess.run(next_element)
一个 feedable 迭代器可以和 tf.placeholder
一起被用于选择什么 迭代器
将会用到在每一次的 tf.Session.run
调用,通过我们熟悉的 feed_dict
机制。它同 reinitializable 迭代器一样提供相同的功能,但是它不需要你去初始化迭代器从一开始的数据集当你在迭代器之间切换时。例如,使用同一个训练和言行样本时,你可以使用 tf.data.Iterator.from_string_handle' 去定义一个
feedable` 迭代器,允许你在两个数据集之间切换。
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
# Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle})
# Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})
从迭代器中消耗数值
Iterator.get_next()
方法返还一个或多个 tf.Tensor
对象,它们对应着迭代器的象征的下一个元素。每一次这些张量都会被评估,它们获得下一个元素的数值在下面的数据集。
如果迭代器循环到数据集的最后,继续执行 Iterator.get_next()
将会报错 tf.errors.OutOfRangeError
。在这一点之后,迭代器将是一个不可用状态,你必须再次初始化它,如果你想进一步使用。
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)
sess.run(iterator.initializer)
print(sess.run(result)) # ==> "0"
print(sess.run(result)) # ==> "2"
print(sess.run(result)) # ==> "4"
print(sess.run(result)) # ==> "6"
print(sess.run(result)) # ==> "8"
try:
sess.run(result)
except tf.errors.OutOfRangeError:
print("End of dataset") # ==> "End of dataset"
一个常用模式是包装 “training loop” 在 try-except
中:
sess.run(iterator.initializer)
while True:
try:
sess.run(result)
except tf.errors.OutOfRangeError:
break
如果数据集的每个元素有一个嵌套的结构,那 Iterator.get_next()
返还的值将是一个或者多个 tf.Tensor
对象且是同样的嵌套结构。
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()
保存迭代器状态
tf.contrib.data.make_saveable_from_iterator
方法将一个迭代器创建为一个 SaveableObject
,他可以被用于保存和还原当前迭代器的状态。一个可保存的对象可以被添加到 tf.train.Saver
变量列表中或者是 tf.GraphKeys.SAVEABLE_OBJECTS
集合中。
# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
with tf.Session() as sess:
if should_checkpoint:
saver.save(path_to_checkpoint)
# Restore the iterator state.
with tf.Session() as sess:
saver.restore(sess, path_to_checkpoint)
读取输入数据
处理 NumPy 数组
假如你的所有输入数据在内存里,那把它们创建成一个数据集的最简单方式是用
Dataset.from_tensor_slices()
转换它们为 tf.Tensor
对象。
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
注意以上的代码片段将会嵌入 features
和 labels
数组到你的 TensorFlow graph 作为 tf.constant()
运算。这对小数据集的表现效果很好,但是浪费内存,因为数组将会被复制多次,对于 tf.GraphDef
protocal buffer 可以运行2GB的限制。
另一种方式,你可以敌营一个数据集根据 tf.placeholder()
张量,然后喂进 NumPy 数组,当你在一个数据集上初始化一个 Iterator
。
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})
处理 TFRecord 数据
tf.data
API 支持不同的文件格式所以你可以处理较大的数据集不需要在内存中。例如,TFRecord 文件格式是许多TensorFlow应用用作训练数据的一个简单的面向记录的二进制格式。tf.data.TFRecordDataset
类使你能够流遍一个或多个 TFRecord 文件的内容,作为输入流程管道的一部分。
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
参数 filenames
对于 tf.data.TFRecordDataset
初始化器可以是字符串,字符串的列表,或者字符串的 tf.Tensor
。因此如果你有两个文件集想要训练和验证,你可以使用 tf.placeholder(tf.string)
来代表文件名,然后初始化迭代器从合适的文件名:
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.
# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
处理 text 数据
许多数据集被分布在一个或者多个文本文件中。tf.data.TextLineDataset
提供一个简单的方法来提取所有行数据从一个或多个文本文佳。给与一个或多个文件名, 一个 TextLineDataset
将会处理这些文件的每一行的字符串数值元素。像一个 TFRecordDataset,
TextLineDataset接受
filenames作为一个
tf.Tensor,所以你可以参数化它通过传入一个
tf.placeholder(tf.string)`。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
默认的,一个 TextLineDataset
生成每个文件的每行数据,它也许不尽人意,例如法国文件始于一个头行,或者包含一些注释。这些行使用 Dataset.skip()
进行移除 和 Dataset.filter()
进行转换。为了应用这些转换对于每个文件各自地,我们使用 Dataset.flat_map()
去创建一个嵌套 Dataset
对于每个文件。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
lambda filename: (
tf.data.TextLineDataset(filename)
.skip(1)
.filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
预处理数据通过使用 Dataset.map()
Dataset.map(f)
变换使用一个已给的方法 f
应用到输入数据集的每个元素,处理成一个新的数据集。 基于 map() function
这是一个常用于列表或者其他结构在函数编程语言中。
解析 tf.Example
protocol buffer messages
许多输入流程管道抓取 tf.train.Example
protocol buffer messages 从一个 TFRecord-format 文件(用 tf.python_io.TFRecordWriter
来写的)。每条 tf.train.Example
记录包含一个或者多个 "特征",输入流程管道典型的转换这些特征为张量。
# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
解码图片数据,调整大小
当训练一个神经网络在真实世界图片数据时,有必要把不同图片大小调整为通常的大小,以便他们可能批处理成一个固定的大小。
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_image(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
应用任意的 Python logic 通过 tf.py_func()
对于性能的原因,我们鼓励你在任何时候使用 TensorFlow 操作来预处理你的数据。但是,有些时候是有用的来调用外部 Python 库当解析你的输入数据的时候。调用 tf.py_func()
操作在一个 Dataset.map()
转换。
import cv2
# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
return image_decoded, label
# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
image_decoded.set_shape([None, None, None])
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
lambda filename, label: tuple(tf.py_func(
_read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)
批处理数据集元素
简单的批处理
最简单的批处理形式是堆一个数据集的n个连续的元素为一个单一的元素。Dataset.batch()
转换准确地做这项工作,同 tf.stack()
操作符的约束相同,应用到元素的每个组成部分。
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)
iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
用 padding 批处理张量
以上的方法对张量具有相同的大小有用。但是,许多模型,例如序列模型,处理输入的数据可以有不同的大小。为了解决这种情况,Dataset.padded_batch()
转换允许你批处理不同形状的张量通过指定一个或者多个需要padded的维度。
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0],
# [5, 5, 5, 5, 5, 0, 0],
# [6, 6, 6, 6, 6, 6, 0],
# [7, 7, 7, 7, 7, 7, 7]]
Dataset.padded_batch()
转换让你设定不同的 padding 对于每个组成部分的每个维度,又或者是可变的长度或者固定的长度。也可能复写 padding 的值,默认为 0。
训练流程
处理多个 epochs
tf.data
API 提供两个主要的方式来处理同样数据的多个 epochs。
最简单的方法来迭代一个数据集多次 epochs 是使用 Dataset.repeat()
变换。例如,创建一个重复自身输入为 10 epochs 的数据集
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)
应用 Dataset.repeat()
转换不带参数时,将会无限地重复输入。Dataset.repeat()
变换连接它的参数不需要标志一个 epoch 的结束,和下一个 epoch 的开始。
如果你想要接受每个 epoch 结束的信号,你可以写一个训练循环来捕捉数据集结束的 tf.errors.OutOfRangeError
。在那一节点,你可以收集一些统计数据在每一个 epoch 中。
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Compute for 100 epochs.
for _ in range(100):
sess.run(iterator.initializer)
while True:
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
break
# [Perform end-of-epoch calculations here.]
随机打乱输入数据
Dataset.shuffle()
变换随机地打乱输入数据集通过一个相似的算法 tf.RandomShuffleQueue
:它维护一个固定的大小缓存和chooses the next element uniformly at random from that buffer。
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
使用高级APIs
tf.train.MonitoredTrainingSession
API 简化运行 TensorFlow 的许多方面在一个分布式的设置。MonitoredTrainingSession
使用 tf.errors.OutOfRangeError
来表示训练已经完成,所以使用它通过 ‘tf.data’ API, 建议使用 Dataset.make_one_shot_iterator()
。例如:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)
training_op = tf.train.AdagradOptimizer(...).minimize(loss)
with tf.train.MonitoredTrainingSession(...) as sess:
while not sess.should_stop():
sess.run(training_op)
为了使用一个 Dataset
在 tf.estimator.Estimator
的 input_fn
中,建议使用 Dataset.make_one_shot_iterator()
。例如:
def dataset_input_fn():
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
# Use `tf.parse_single_example()` to extract data from a `tf.Example`
# protocol buffer, and perform any additional per-record preprocessing.
def parser(record):
keys_to_features = {
"image_data": tf.FixedLenFeature((), tf.string, default_value=""),
"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
"label": tf.FixedLenFeature((), tf.int64,
default_value=tf.zeros([], dtype=tf.int64)),
}
parsed = tf.parse_single_example(record, keys_to_features)
# Perform additional preprocessing on the parsed data.
image = tf.image.decode_jpeg(parsed["image_data"])
image = tf.reshape(image, [299, 299, 1])
label = tf.cast(parsed["label"], tf.int32)
return {"image_data": image, "date_time": parsed["date_time"]}, label
# Use `Dataset.map()` to build a pair of a feature dictionary and a label
# tensor for each example.
dataset = dataset.map(parser)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
# `features` is a dictionary in which each value is a batch of values for
# that feature; `labels` is a batch of labels.
features, labels = iterator.get_next()
return features, labels
2018,03,29 更新
网友评论