美文网首页
【学习tensorflow2】有用的API汇总

【学习tensorflow2】有用的API汇总

作者: WILeroy | 来源:发表于2020-08-21 12:03 被阅读0次
  1. 动态显存分配
from tensorflow.compat.v1 import ConfigProto, InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
  1. 随机数种子: 为
tf.random.set_seed(2317)
  1. 混合精度
opt = Adam()
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
model.compile(optimizer=opt, loss="...")
  1. 中断训练与继续训练
reloaded = False
# 参数为键值对, 如global_epoch=global_epoch, 等式左边是key(自行定义), 右边是value(tf的变量, 模型, 优化器等).
checkpoint = tf.train.Checkpoint(global_epoch=global_epoch, model=model)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if reloaded:
    checkpoint.restore(manager.latest_checkpoint)
while True:
    # Train.
    manager.save()
  1. 日志可视化
log_writer = tf.summary.create_file_writer(log_dir)
def write_log(l, name):
    with log_writer.as_default():
        tf.summary.scalar(name, l, step=global_epoch)
    log_writer.flush()
# 使用tensorboard --logdir [log_dir]可视化日志.

相关文章

网友评论

      本文标题:【学习tensorflow2】有用的API汇总

      本文链接:https://www.haomeiwen.com/subject/dxrujktx.html