美文网首页
Tf上保存变量

Tf上保存变量

作者: vincehxb | 来源:发表于2017-07-05 21:35 被阅读0次

1.定义文件的保存路径

ckpt_dir="./ckpt_dir"

    ifnotos.path.exists(ckpt_dir):

    os.makedirs(ckpt_dir)

2.定义一个全局变量

global_step=tf.Variable(0,name='global_step',trainable=False)

这个全局变量是保存文件和提取文件的标识,比如我现在要load什么时候保存的变量

3.定义saver方法

saver=tf.train.Saver()

注意任何变量定义在saver前面的都会被保存,在其后面定义的都不会被保存

4.保存变量

注意看前面定义的变量global_step,第一步给这个变量更新值(epoch),然后再保存。所以这个变量是以后load哪个文件的依据

global_step.assign(i).eval()#set and update(eval) global_step with index, i

saver.save(sess, ckpt_dir+"/model.ckpt",global_step=global_step)

5.load变量

ckpt=tf.train.get_checkpoint_state(ckpt_dir)

if ckpt and ckpt.model_checkpoint_path:

    print(ckpt.model_checkpoint_path)

    saver.restore(sess, ckpt.model_checkpoint_path)#restore all variables

相关文章

网友评论

      本文标题:Tf上保存变量

      本文链接:https://www.haomeiwen.com/subject/jzfchxtx.html