一:变量管理
#在生成上下文管理器时,将参数reuse设置为True。
#这样tf.get_variable函数将直接获取已经声明的变量.
#可是,若该命名空间还未创建变量v1将报错。
#相反的,若将参数reuse设置为False或None,tf.get_variable将创建新变量,若同名变量已经存在,将报错
with tf.variable_scope('foo', reuse=True):
v1 = tf.get_variable('v', shape=[1]) #获取已经声明的变量v1
二:模型持久化方法
1.保存模型
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1'))
v2 = tf.Variable(tf.constant(1.0, shape=[1], name='v2'))
result = v1 + v2
init_op = tf.initialize_all_variables()
#用于保存类型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
#将模型保存到该文件model.ckpt
saver.save(sess, '/path/to/model/model.ckpt')
2.加载已保存的模型
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1')) #声明变量的名称必须与保存变量的名称相同,否则出错
v2 = tf.Variable(tf.constant(1.0, shape=[1], name='v2'))
result = v1 + v2
#用于保存类型
saver = tf.train.Saver()
with tf.Session() as sess:
#加载已保存的模型,并通过已保存的变量值来计算加法
saver.restore(sess, '/path/to/model/model.ckpt')
sess.run(result)
网友评论