美文网首页TensorFlow
tensorflow中检查点的使用

tensorflow中检查点的使用

作者: 上行彩虹人 | 来源:发表于2018-11-12 22:39 被阅读3次

    保存模型并不限于在训练之后,在训练之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况。我们自然希望能够将辛苦得到的中间参数保留下来,否则下次又要重新开始。这种在训练中保存模型,习惯上称之为保存检查点。
    1、线性回归例子

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    #训练数据
    train_x = np.linspace(-1,1,100)
    train_y = 2* train_x + np.random.randn(*train_x.shape)*0.3
    
    tf.reset_default_graph()
    
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    
    w = tf.Variable(tf.random_normal([1]),name='weight')
    b = tf.Variable(tf.zeros([1]),name='bias')
    predict = tf.multiply(w,x)+b
    
    cost = tf.reduce_mean(tf.square(y-predict))
    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
    
    init = tf.global_variables_initializer()
    
    training_epochs = 200
    display_step= 2
    

    2、保存检查点

    #                    max_to_keep 保存的检查点个数
    saver = tf.train.Saver(max_to_keep=2)
    savedir = 'log/'
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(training_epochs):
            # for(x,y)  in zip(train_x,train_y):
            sess.run(optimizer,feed_dict={x:train_x,y:train_y})
            loss = sess.run(cost,feed_dict={x:train_x,y:train_y})
            print('epoch:',epoch,'loss',loss)
            #保存检查点
            saver.save(sess,savedir+'linemodel.cpkt',global_step=epoch)
        print('Finish')
    
        plt.plot(train_x,train_x,color='green')
        plt.plot(train_x,sess.run(w)*train_x+sess.run(b),color='red')
        plt.legend()
        plt.show()
    

    log文件夹下生成的文件

    xia

    3、另起一个session载入保存的检查点

    with tf.Session() as sess2:
        sess2.run(init)
        saver.restore(sess2,savedir+'linemodel.cpkt-'+str(198))
        print(sess2.run(w))
        print(10*sess2.run(w)+sess2.run(b))
    

    相关文章

      网友评论

        本文标题:tensorflow中检查点的使用

        本文链接:https://www.haomeiwen.com/subject/spuhfqtx.html