- Dynamic GPU memory allocation
from tensorflow.compat.v1 import ConfigProto, InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
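The same effect can also be achieved with the native TF2 configuration API instead of the compat.v1 session (a minimal sketch, assuming TensorFlow 2.1+ where tf.config.list_physical_devices and tf.config.experimental.set_memory_growth are available):
import tensorflow as tf

# Must be called before any GPU has been initialized.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)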
- Random seed (for reproducibility)
import tensorflow as tf
tf.random.set_seed(2317)
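For stricter reproducibility it is common to also seed Python's built-in and NumPy generators (a sketch; 2317 is simply the same seed used above):
import random
import numpy as np

random.seed(2317)
np.random.seed(2317)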
- Mixed precision
from tensorflow.keras.optimizers import Adam

opt = Adam()
# Wrap the optimizer: the graph rewrite casts eligible ops to float16 and adds loss scaling.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
model.compile(optimizer=opt, loss="...")
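On newer TensorFlow releases the graph-rewrite API above is deprecated; a minimal sketch of the Keras policy-based alternative (assuming TF 2.4+, where tf.keras.mixed_precision.set_global_policy exists):
from tensorflow.keras import mixed_precision

# Layers compute in float16 while keeping float32 variables;
# model.compile wraps the optimizer in a LossScaleOptimizer automatically.
mixed_precision.set_global_policy("mixed_float16")
model.compile(optimizer=Adam(), loss="...")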
- Interrupting and resuming training
reloaded = False  # set to True to resume from the latest checkpoint
# Arguments are key-value pairs, e.g. global_epoch=global_epoch: the left side is the key
# (any name you choose), the right side is the value (a TF variable, model, optimizer, etc.).
checkpoint = tf.train.Checkpoint(global_epoch=global_epoch, model=model)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if reloaded:
    checkpoint.restore(manager.latest_checkpoint)
while True:
    # Train.
    manager.save()
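A usage sketch of a resume-capable loop, assuming global_epoch is a tf.Variable counting finished epochs, and num_epochs / train_one_epoch are hypothetical placeholders for your own schedule and training code:
for _ in range(int(global_epoch.numpy()), num_epochs):
    train_one_epoch(model)       # hypothetical: one pass over the training data
    global_epoch.assign_add(1)   # advance the counter stored in the checkpoint
    manager.save()               # keeps at most the 3 most recent checkpoints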
- Log visualization
log_writer = tf.summary.create_file_writer(log_dir)

def write_log(l, name):
    with log_writer.as_default():
        tf.summary.scalar(name, l, step=global_epoch)
        log_writer.flush()
# Visualize the logs with `tensorboard --logdir [log_dir]`.
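Usage sketch inside the training loop (loss is a hypothetical scalar tensor or float computed by your training step):
write_log(loss, "train_loss")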