美文网首页
TensorFlow学习笔记:Saver与Restore

TensorFlow学习笔记:Saver与Restore

作者: DexterLei | 来源:发表于2017-10-02 14:50 被阅读4344次

    tf.train.Saver

    注意:saver()与restore()只是保存了session中的相关变量对应的值,并不涉及模型的结构。


    Saver

    Defined in tensorflow/python/training/saver.py.
    See the guides: Exporting and Importing a MetaGraph > Exporting a Complete Model to MetaGraph, Exporting and Importing a MetaGraph, Variables > Saving and Restoring Variables
    Saves and restores variables.
    See Variables for an overview of variables, saving and restoring.

    Saver的作用是将我们训练好的模型的参数保存下来,以便下一次继续用于训练或测试;Restore则是将训练好的参数提取出来。Saver类训练完后,是以checkpoints文件形式保存。提取的时候也是从checkpoints文件中恢复变量。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。

    一般地,Saver会自动的管理Checkpoints文件。我们可以指定保存最近的N个Checkpoints文件,当然每一步都保存ckpt文件也是可以的,只是没必要,费存储空间。

    • saver()可以选择global_step参数来为ckpt文件名添加数字标记:
    saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
    ...
    saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
    
    • max_to_keep参数定义saver()将自动保存的最近n个ckpt文件,默认n=5,即保存最近的5个检查点ckpt文件。若n=0或者None,则保存所有的ckpt文件。
    • keep_checkpoint_every_n_hoursmax_to_keep类似,定义每n小时保存一个ckpt文件。
    ...
    # Create a saver.
    saver = tf.train.Saver(...variables...)
    # Launch the graph and train, saving the model every 1,000 steps.
    sess = tf.Session()
    for step in xrange(1000000):
        sess.run(..training_op..)
        if step % 1000 == 0:
            # Append the step number to the checkpoint name:
            saver.save(sess, 'my-model', global_step=step)
    

    一个简单的例子:

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    
    import tensorflow as tf
    import time
    
    time.clock()
    
    x = tf.placeholder(tf.float32 ,[None, 784])
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x,W) + b)
    
    # 为了计算交叉熵,我们需要添加一个新的占位符用于输入正确值。
    y_ = tf.placeholder(tf.float32, [None,10])
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    # 在此,我们要求TF使用梯度下降算法,并以0.01的学习速率最小化交叉熵。
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    
    # 创建Saver节点,并设置自动保存最近n=1次模型
    saver = tf.train.Saver(max_to_keep=1)
    saver_max_acc = 0 
    for i in range(100):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})
        correct_prediction = tf.equal(tf.argmax(y,1), tf.arg_max(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
        if (i+1)%10 == 0:
            print('{0:0>2d}:{1:.4f}'.format((i+1),accuracy.eval(session=sess, feed_dict={x: mnist.test.images, y_: mnist.test.labels})))
        # 添加判断语句,选择保存精度最高的模型
        if accuracy > saver_max_acc:
            saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
            saver_max_acc = accuracy
    sess.close()
    print(time.clock())
    

    Restore

    restore(sess, save_path)
    # sess: A Session to use to restore the parameters.
    # save_path: Path where parameters were previously saved.
    
    • sess: 保存参数的会话。
    • save_path: 保存参数的路径。
    • 当从文件中恢复变量时,不需要事先对他们进行初始化,因为“恢复”自身就是一种初始化变量的方法。
    • 可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
    model_file=tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess,model_file)
    

    参考资料:

    1. tensorflow 1.0 学习:模型的保存与恢复(Saver)
    2. 莫烦 Tensorflow 19 Saver 保存读取 (神经网络 教学教程tutorial)
    3. TensorFlow手把手入门之 — TensorFlow保存还原模型的正确方式,Saver的save和restore方法,亲测可用

    相关文章

      网友评论

          本文标题:TensorFlow学习笔记:Saver与Restore

          本文链接:https://www.haomeiwen.com/subject/sictyxtx.html