美文网首页TensorFlow操作
TensorFlow 加载部分变量

TensorFlow 加载部分变量

作者: 翻开日记 | 来源:发表于2018-07-13 11:01 被阅读0次
    for v in tf.global_variables():
        if 'global_step' in v.name:
            var2.append(v) # 重置global step 调整学习率重新训练
        else:
            var1.append(v)
    var1 = tf.train.Saver(var1)
    ...
    """Restore Model"""
    save_file = tf.train.latest_checkpoint(hparams_at.train_dir)
    if save_file:
        print(save_file)
        step = int(save_file.split('ckpt-')[-1]) + 1
        var1.restore(sess, save_file)
        sess.run(tf.variables_initializer([model.global_step]))
        print("Go on")
    else:
        step = 0
        sess.run(tf.global_variables_initializer())
        print("Begin")
    
    

    相关文章

      网友评论

        本文标题:TensorFlow 加载部分变量

        本文链接:https://www.haomeiwen.com/subject/mgljpftx.html