可以参考官方的文件.
tf.data
提供了一整套复杂的数据输入和使用的方法. 比如图片的数据可能包含图片的数据(image)和它的标签(label).
1. 创建
有两种创建的方式:
- 从内存中获取呢使用函数
tf.data.Dataset.from_tensors()
ortf.data.Dataset.from_tensor_slices()
. - 从TFRecord 格式的文件中
tf.data.TFRecordDataset()
.
1.1 .from_tensors
from_tensors
会将数据压缩成一组元素.
1.2 .from_tensor_slices
from_tensor_slices
会将数据压缩, 然后以他们的第一个维数进行分组.
dataset0 = tf.data.Dataset.from_tensors([8, 3, 0, 8, 2, 1])
dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
for elem in dataset:
print(elem.numpy())
from_tensors
处理后的dataset0
会包含一个元素, 而from_tensor_slices
处理后会包含多个元素.
使用.numpy()
的方式转换成NumPy. 这里使用遍历(for
)的方式打印.
也可以使用迭代器
it = iter(dataset)
print(next(it).numpy())
可以使用reduce
遍历所有的元素, 生成一个元素, 如
print(dataset.reduce(0, lambda state, value: state + value).numpy())
2. 数据结构
可以使用Dataset.element_spec
查看数据类型
Dataset.map()
和Dataset.filter()
可以被执行用于所有的元素.
3. 批处理所有的元素
使用Dataset.batch()
进行批处理
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
# dataset 为 ZipDataset 类似元组
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
# 将其4个一组打包
batched_dataset = dataset.batch(4)
# .take(4)选择前4个
for batch in batched_dataset.take(4):
print([arr.numpy() for arr in batch])
使用drop_remainder
可以舍弃剩余项
batched_dataset = dataset.batch(7, drop_remainder=True)
网友评论