美文网首页LSTM深入理解tensorflowTensorflow
TensorFlow 模型保存和恢复示例

TensorFlow 模型保存和恢复示例

作者: 祝威廉 | 来源:发表于2017-05-31 17:14 被阅读2422次

    前言

    在之前一篇文章里:使用CNN+ Auto-Encoder 实现无监督Sentence Embedding (代码基于Tensorflow),训练完成后,encode的参数也就被训练好了,这个时候我们利用这些参数对数据进行编码处理,从而得到一个向量。

    保存模型

    如果回忆下,上次的模型基本是这样的:

    Input(段落) -> encoder -> encoder -> decoder -> decoder -> lost function (consine夹角)
    

    我需要用到的是第二个encoder,在Tensorflow里,所有的都是Tensor,因此给定输入,就可以通过tensor给出输出。训练的过程中,涉及到第二个encoder的代码如下:

    ....
    flattened = tf.reshape(conv_out, [-1, 51 * 128]) if USE_CNN else tf.reshape(input_x,
                                                                                    [-1, SEQUENCE_LENGTH * VOCAB_SIZE])
    
    encoder_op = encoder(flattened)
    ....
    

    我们真个训练过程其实是在tunning encoder的参数。现在我需要把encoder_op保留下来,供下次使用,这可以通过add_collection方法

    tf.add_to_collection('encoder_op', encoder_op)
    

    sess.run(tf.global_variables_initializer()) 之后,我们获取Saver对象:

    saver = tf.train.Saver()
    

    然后在迭代的过程中,比如每迭代五次就保存一次模型:

    if i %5 = 0: 
        saver.save(sess, MODEL_SAVE_DIR)
    

    恢复模型

    sess = tf.Session()
    ## 这里是恢复graph
    saver = tf.train.import_meta_graph(MODEL_SAVE_DIR + '/' + MODEL_NAME + '.meta')
    ## 这里是恢复各个权重参数
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_SAVE_DIR))
    
    
    sess.run(tf.global_variables_initializer())
    ## 获取输入的tensor
    input_x = tf.get_default_graph().get_tensor_by_name("input_x:0")
    ......
    
    x_in = result1[0:SEQUENCE_LENGTH]
    ## 获取到encoder_op
    encoder_op = tf.get_collection("encoder_op")[0]
    ## 给定数据,运行encoder_op 
    s = sess.run(encoder_op, feed_dict={input_x: [x_in]})
    

    具体的解释已经在代码中提及。这样我们就可以利用encoder_op对新数据进行编码了。

    完整的恢复模型参看:tensorflow_restore.py

    额外的话

    参考资料:

    A quick complete tutorial to save and restore Tensorflow models

    在该参考资料中,你还可以看到多种保存和使用tensor的方式。另外除了保存模型以外,还有 tf.summary.FileWriter

    train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
    

    使用它可以让你通过tensorbord 查看训练和运行情况。

    相关文章

      网友评论

      本文标题:TensorFlow 模型保存和恢复示例

      本文链接:https://www.haomeiwen.com/subject/zzdjfxtx.html