This post contains my study notes on the official Keras callbacks documentation.
Callbacks are a very powerful tool in TF2. They are what power, for example, TensorBoard integration (tf.keras.callbacks.TensorBoard) and periodic checkpoint saving.
Overview
All callbacks subclass the keras.callbacks.Callback class. A list of callbacks can be passed to Model.fit(), Model.evaluate(), and Model.predict().
Callback methods
As the methods below show, callbacks fire at specific points in the training/testing/prediction lifecycle; a callback's job is to perform some task at exactly those points.
Global methods
- on_(train|test|predict)_begin(self, logs=None)
  Called at the beginning of fit/evaluate/predict.
- on_(train|test|predict)_end(self, logs=None)
  Called at the end of fit/evaluate/predict.
Batch-level methods for training/testing/predicting
- on_(train|test|predict)_batch_begin(self, batch, logs=None)
  Called right before processing a batch during training/testing/predicting.
- on_(train|test|predict)_batch_end(self, batch, logs=None)
  Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.
Epoch-level methods (training only)
- on_epoch_begin(self, epoch, logs=None)
  Called at the beginning of an epoch during training.
- on_epoch_end(self, epoch, logs=None)
  Called at the end of an epoch during training.
Example
import tensorflow as tf
from tensorflow import keras

# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Log Dict
The logs dict contains the loss and metric values recorded during training; they are updated at the end of each batch and each epoch.
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
Using the self.model attribute
In addition to the logs dict, a callback also has access to the model currently being run, via the self.model attribute.
Things you can do with self.model (a sketch follows the list):
- Set self.model.stop_training = True to interrupt training immediately.
- Mutate hyperparameters of the optimizer via self.model.optimizer (e.g. self.model.optimizer.learning_rate).
- Save the model at periodic intervals.
- Record the output of model.predict() on a few test samples at the end of each epoch, as a sanity check during training.
- Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
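As a minimal sketch of the first three points (the callback name, the save_every parameter, and the checkpoint path are illustrative, not from the official guide):

import numpy as np
import tensorflow as tf
from tensorflow import keras

class ModelAccessCallback(keras.callbacks.Callback):
    """Illustrative sketch exercising the self.model hooks listed above."""

    def __init__(self, save_every=5):
        super(ModelAccessCallback, self).__init__()
        self.save_every = save_every  # hypothetical parameter

    def on_epoch_end(self, epoch, logs=None):
        # 1. Stop training based on a condition (here: a non-finite loss).
        if logs and not np.isfinite(logs["loss"]):
            self.model.stop_training = True
        # 2. Read an optimizer hyperparameter through self.model.optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        print("Epoch {}: learning rate is {:.4f}".format(epoch, lr))
        # 3. Save the model periodically (the path is illustrative).
        if (epoch + 1) % self.save_every == 0:
            self.model.save("model_epoch_{}.h5".format(epoch + 1))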
Keras callback examples
Early Stopping
The first example shows how to create a callback that stops training once the loss reaches its minimum, by setting the boolean attribute self.model.stop_training. For a more complete implementation, see tf.keras.callbacks.EarlyStopping. A minimal usage sketch of that built-in callback is shown below, followed by a custom implementation.
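The monitor and patience values here are arbitrary choices for illustration:

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="loss",             # quantity to watch
    patience=2,                 # epochs with no improvement before stopping
    restore_best_weights=True,  # roll back to the best weights seen
)
model = get_model()
model.fit(x_train, y_train, epochs=30, verbose=0, callbacks=[early_stop])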
import numpy as np

class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
          number of epochs with no improvement, training stops.
    """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights stores the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs waited while the loss is no longer the minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
Adjusting the learning rate
This example demonstrates how to adjust the learning rate dynamically during training. The built-in callbacks.LearningRateScheduler provides a more complete solution; a minimal usage sketch is shown below, followed by a custom implementation.
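The built-in callback accepts a schedule function with the same (epoch, lr) -> new_lr signature used further below; the decay rule here is an arbitrary illustration:

def halve_after_ten(epoch, lr):
    # Keep the initial rate for the first 10 epochs, then halve it each epoch.
    return lr if epoch < 10 else lr * 0.5

lr_callback = tf.keras.callbacks.LearningRateScheduler(halve_after_ten, verbose=1)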
class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

    Arguments:
        schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
    """

    def __init__(self, schedule):
        super(CustomLearningRateScheduler, self).__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from the model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # Call the schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts.
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, scheduled_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))
LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]

def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)
Built-in Keras callbacks in TF2
Note that TF2 already ships with the most common callbacks, so there is no need to reinvent the wheel. Before writing your own, check the [callback reference](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/).
Here is the list of built-in callbacks as of TF 2.4:
- BaseLogger: Callback that accumulates epoch averages of metrics.
- CSVLogger: Callback that streams epoch results to a CSV file.
- Callback: Abstract base class used to build new callbacks.
- CallbackList: Container abstracting a list of callbacks.
- EarlyStopping: Stop training when a monitored metric has stopped improving.
- History: Callback that records events into a History object.
- LambdaCallback: Callback for creating simple, custom callbacks on-the-fly.
- LearningRateScheduler: Learning rate scheduler.
- ModelCheckpoint: Callback to save the Keras model or model weights at some frequency.
- ProgbarLogger: Callback that prints metrics to stdout.
- ReduceLROnPlateau: Reduce learning rate when a metric has stopped improving.
- RemoteMonitor: Callback used to stream events to a server.
- TensorBoard: Enable visualizations for TensorBoard.
- TerminateOnNaN: Callback that terminates training when a NaN loss is encountered.
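In practice, the built-ins are simply combined in the callbacks list passed to fit(). A sketch, where the file paths and parameter values are arbitrary illustrations:

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath="best_model.h5", monitor="val_loss", save_best_only=True
    ),
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3),
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
    tf.keras.callbacks.TensorBoard(log_dir="./logs"),
    tf.keras.callbacks.CSVLogger("training_log.csv"),
]

model = get_model()
model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=30,
    verbose=0,
    callbacks=callbacks,
)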