保存模型并不限于在训练之后,在训练之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况。我们自然希望能够将辛苦得到的中间参数保留下来,否则下次又要重新开始。这种在训练中保存模型,习惯上称之为保存检查点。
1、线性回归例子
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#训练数据
train_x = np.linspace(-1,1,100)
train_y = 2* train_x + np.random.randn(*train_x.shape)*0.3
tf.reset_default_graph()
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(tf.random_normal([1]),name='weight')
b = tf.Variable(tf.zeros([1]),name='bias')
predict = tf.multiply(w,x)+b
cost = tf.reduce_mean(tf.square(y-predict))
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
init = tf.global_variables_initializer()
training_epochs = 200
display_step= 2
2、保存检查点
# max_to_keep 保存的检查点个数
saver = tf.train.Saver(max_to_keep=2)
savedir = 'log/'
with tf.Session() as sess:
sess.run(init)
for epoch in range(training_epochs):
# for(x,y) in zip(train_x,train_y):
sess.run(optimizer,feed_dict={x:train_x,y:train_y})
loss = sess.run(cost,feed_dict={x:train_x,y:train_y})
print('epoch:',epoch,'loss',loss)
#保存检查点
saver.save(sess,savedir+'linemodel.cpkt',global_step=epoch)
print('Finish')
plt.plot(train_x,train_x,color='green')
plt.plot(train_x,sess.run(w)*train_x+sess.run(b),color='red')
plt.legend()
plt.show()
log文件夹下生成的文件
xia3、另起一个session载入保存的检查点
with tf.Session() as sess2:
sess2.run(init)
saver.restore(sess2,savedir+'linemodel.cpkt-'+str(198))
print(sess2.run(w))
print(10*sess2.run(w)+sess2.run(b))
网友评论