定义变量并初始化：通常初始化为随机值或常值
# Define variables and initialize them (usually with random values or
# constants).
# NOTE: TensorFlow 1.x graph-mode API (tf.Session, tf.train.Saver, ...);
# under TF 2.x these live in tf.compat.v1.
from tensorflow.python.framework import ops

# Reset the default graph *before* creating any variables. Creating
# `weights` first and resetting afterwards would leave it in the discarded
# graph: init_op would not initialize it and sess.run(weights) would fail
# because the tensor belongs to a different graph than the session.
ops.reset_default_graph()

weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name='weights')
biases = tf.Variable(tf.zeros([200]), name='biases')

# global_variables_initializer() only covers variables that already exist in
# the default graph, so it must be created after them.
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    # print(sess.run(weights))
保存变量
# Save variables to a checkpoint.
from tensorflow.python.framework import ops
# No ops.reset_default_graph() needed: everything below is built in the
# explicit graph g1, not in the default graph.
g1 = tf.Graph()
print(g1)
with g1.as_default():
    # Initialize variables from another variable. initialized_value() makes
    # w2 / w_twice read `weights` only after it has been initialized;
    # global_variables_initializer() does not order the initializers itself.
    weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                          name='weights')
    w2 = tf.Variable(weights.initialized_value(), name='w2')
    # Twice the value of `weights`, as the variable name says.
    w_twice = tf.Variable(weights.initialized_value() * 2.0, name='w_twice')
    # The init op and the Saver must be created inside g1 so they see g1's
    # variables. A no-argument Saver stores every variable under its op
    # name ('weights', 'w2', 'w_twice').
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
with tf.Session(graph=g1) as sess:
    sess.run(init_op)
    print(sess.run(weights))
    save_path = saver.save(sess, '/tmp/model.ckpt')
    print('Model saved in file: ', save_path)
恢复变量
# Restore variables into a fresh graph (g2), so names cannot collide with
# variables created in earlier graphs.
g2 = tf.Graph()
with g2.as_default():
    # The op names ('weights', 'w2', 'w_twice') must match the keys stored
    # in /tmp/model.ckpt. Initial values do not matter: restore() overwrites
    # all three, so no initializer is run below.
    weightss = tf.Variable(tf.zeros([784, 200]), name='weights')
    w_2 = tf.Variable(weightss, name='w2')
    w_t = tf.Variable(weightss, name='w_twice')
    print(weightss.graph)
    # The Saver must be created inside g2; built outside it would find no
    # variables to manage.
    saver = tf.train.Saver()
with tf.Session(graph=g2) as sess:
    saver.restore(sess, '/tmp/model.ckpt')
    # print(sess.run(weightss))
    # print(sess.run(w_2))
    print(sess.run(w_t))
保存部分变量
# Save only a subset of the variables.
from tensorflow.python.framework import ops
ops.reset_default_graph()
g1 = tf.Graph()
print(g1)
with g1.as_default():
    # Initialize variables from another variable; initialized_value() makes
    # them wait until `weights` has been initialized.
    weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                          name='weights')
    w2 = tf.Variable(weights.initialized_value(), name='w2')
    # Twice the value of `weights`, as the variable name says.
    w_twice = tf.Variable(weights.initialized_value() * 2.0, name='w_twice')
    # Save only w2 and w_twice. The dict keys ('my_w2', 'my_wt') become the
    # names stored in the checkpoint; `weights` is not written at all.
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver({'my_w2': w2, "my_wt": w_twice})
with tf.Session(graph=g1) as sess:
    sess.run(init_op)
    print(sess.run(weights))
    save_path = saver.save(sess, '/tmp/model.ckpt')
    print('Model saved in file: ', save_path)
恢复部分变量
# Restore the partially saved variables into a fresh graph.
g2 = tf.Graph()
with g2.as_default():
    # A no-argument Saver looks variables up by op name. The names must
    # therefore equal the dict keys used when saving ('my_w2', 'my_wt').
    w_2 = tf.Variable(tf.zeros([784, 200]), name='my_w2')
    w_t = tf.Variable(tf.zeros([784, 200]), name='my_wt')
    # Do not add a variable that was never saved (e.g. name='my_weight'):
    # a default Saver expects every variable in the graph to be present in
    # the checkpoint, so restore() would fail.
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
with tf.Session(graph=g2) as sess:
    # Not strictly required, since restore() overwrites both variables.
    sess.run(init_op)
    saver.restore(sess, '/tmp/model.ckpt')
    print(sess.run(w_2))
网友评论