Keras迁移学习完整范例:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# ---- Pretrained feature extractor ----
# Xception with ImageNet weights and without its ImageNet classification head
# (include_top=False); a new task-specific head is attached to its output.
# Xception's default input resolution is 299x299 RGB.
base_model = tf.keras.applications.xception.Xception(
    weights="imagenet",
    input_shape=(299, 299, 3),
    include_top=False,
)

# Freeze every backbone weight so training only updates the new head.
base_model.trainable = False
# ---- Load the training, validation and test data ----
# The Keras directory loader infers labels from sub-directory names by default
# (labels="inferred"), so each data directory presumably holds one folder per
# class -- confirm against the data on disk.
TRAIN_DATASET_PATH = r"D:\nn_tf\cats_vs_dogs\train"  # training data directory
TEST_DATASET_PATH = r"D:\nn_tf\cats_vs_dogs\test"  # test data directory
batch_size = 32
image_size = (299, 299)  # resize target; matches the (299, 299, 3) Xception input

# Training and validation subsets are both carved out of TRAIN_DATASET_PATH.
# They must share the same validation_split and seed or the two subsets could
# overlap; passing one shared options dict to both calls guarantees that.
_split_options = {
    "validation_split": 0.2,  # 20% of the files go to validation
    "seed": 1337,
    "image_size": image_size,
    "batch_size": batch_size,
}
train_dataset = keras.preprocessing.image_dataset_from_directory(
    TRAIN_DATASET_PATH, subset="training", **_split_options
)
val_dataset = keras.preprocessing.image_dataset_from_directory(
    TRAIN_DATASET_PATH, subset="validation", **_split_options
)

# Test data: a separate directory, read without shuffling so batches come out
# in a fixed order.
test_dataset = keras.preprocessing.image_dataset_from_directory(
    TEST_DATASET_PATH,
    image_size=image_size,
    batch_size=batch_size,
    shuffle=False,
)
# Inspect one raw training batch; the loader yields image batches shaped
# (batch_size, height, width, channels).
for image, label in train_dataset.take(1):
    print(image.shape)

# Data augmentation: random horizontal and vertical flips, plus random
# rotations by up to +/-0.2 of a full turn (RandomRotation's factor is a
# fraction of 2*pi).
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
])

# Apply augmentation to the training set only. training=True is required:
# these random layers pass their input through unchanged in inference mode.
# num_parallel_calls lets tf.data augment several batches concurrently; output
# order stays deterministic by default.
train_dataset = train_dataset.map(
    lambda x, y: (data_augmentation(x, training=True), y),
    num_parallel_calls=tf.data.AUTOTUNE,
)
# Overlap input preparation with model execution.
train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

# Inspect an augmented batch; flips/rotations must leave the shape unchanged.
for image, label in train_dataset.take(1):
    print(image.shape)
# ---- Build the transfer-learning model ----
# Xception's own input preprocessing (per the Keras docs it scales pixel values
# into [-1, 1]). Doing it inside the model lets the model be fed the raw pixel
# values produced by the data pipeline.
xception_preprocess = keras.applications.xception.preprocess_input

# Graph: raw image -> Xception preprocessing -> frozen backbone ->
# global average pooling -> one output unit with no activation (a logit).
inputs = keras.Input(shape=(299, 299, 3), name="image_input")
# training=False runs the backbone in inference mode, so layers such as
# BatchNormalization keep using their stored statistics.
feature_maps = base_model(xception_preprocess(inputs), training=False)
# Average every feature map over its spatial dimensions -> one vector per image.
pooled_features = keras.layers.GlobalAveragePooling2D()(feature_maps)
outputs = keras.layers.Dense(1, name="precisions")(pooled_features)
model_with_prepocess = keras.Model(inputs, outputs, name="model_with_prepocess")
# ---- Compile ----
# The head is Dense(1) without an activation, i.e. it emits a raw logit, so the
# loss is built with from_logits=True. BinaryAccuracy thresholds predictions at
# 0.5 by default, which is meant for probabilities; for logits the matching
# decision boundary is 0.0 (sigmoid(logit) > 0.5  <=>  logit > 0).
model_with_prepocess.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(threshold=0.0)],
)

# ---- Train ----
# Only the new head is trainable (the backbone was frozen); validation metrics
# are reported after every epoch.
model_with_prepocess.fit(train_dataset, epochs=10, validation_data=val_dataset)

# ---- Test ----
# With a single metric, evaluate() returns [loss, binary_accuracy].
loss, acc = model_with_prepocess.evaluate(test_dataset)
print(f"Loss is {loss}; Accuracy is {acc} in Test Dataset")

网友评论