为什么要用tf.data API
- 可以以快速且可扩展的方式加载和预处理数据
- 大型数据集 vs 小型数据集(用numpy即可)
- 支持分布式训练
数据增强(Data augmentation) Keras提供两种实现数据增强的方式:
- 将数据增强层在模型中实现,代码如下所示:
model = tf.keras.Sequential([
resize_and_rescale,
data_augmentation,
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
# Rest of your model
])
这种方式的好处是,数据增强可以在GPU中计算,训练速度会更快。使用TensorFlow做部署时,在推理模式下,数据增强层会自动变为恒等输出。当使用其它软件工具部署模型时,例如:OpenVINO,则内嵌入模型的数据增强层行为变得未知
- 不将数据增强功能集成到模型里面,仅仅作用于训练数据,代码如下所示:
aug_ds = train_ds.map(
lambda x, y: (resize_and_rescale(x, training=True), y))
这种方式的好处是,数据增强在CPU中异步执行,可以更好的利用 Dataset.prefetch的优势;另外,数据增强与模型分离,方便模型用第三工具进行部署。本文选用第二种方式实现数据增强
完整范例代码:
data_augmentation = tf.keras.Sequential([
layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
layers.experimental.preprocessing.RandomRotation(0.2),
])
train_dataset = train_dataset.map(lambda x,y: (data_augmentation(x, training=True), y), num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(1000).batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
网友评论