美文网首页LSTM深入理解tensorflowTensorflow
TensorFlow 模型保存和恢复示例

TensorFlow 模型保存和恢复示例

作者: 祝威廉 | 来源:发表于2017-05-31 17:14 被阅读2422次

前言

在之前一篇文章里:使用CNN+ Auto-Encoder 实现无监督Sentence Embedding (代码基于Tensorflow),训练完成后,encode的参数也就被训练好了,这个时候我们利用这些参数对数据进行编码处理,从而得到一个向量。

保存模型

如果回忆下,上次的模型基本是这样的:

Input(段落) -> encoder -> encoder -> decoder -> decoder -> lost function (consine夹角)

我需要用到的是第二个encoder,在Tensorflow里,所有的都是Tensor,因此给定输入,就可以通过tensor给出输出。训练的过程中,涉及到第二个encoder的代码如下:

....
flattened = tf.reshape(conv_out, [-1, 51 * 128]) if USE_CNN else tf.reshape(input_x,
                                                                                [-1, SEQUENCE_LENGTH * VOCAB_SIZE])

encoder_op = encoder(flattened)
....

我们真个训练过程其实是在tunning encoder的参数。现在我需要把encoder_op保留下来,供下次使用,这可以通过add_collection方法

tf.add_to_collection('encoder_op', encoder_op)

sess.run(tf.global_variables_initializer()) 之后,我们获取Saver对象:

saver = tf.train.Saver()

然后在迭代的过程中,比如每迭代五次就保存一次模型:

if i %5 = 0: 
    saver.save(sess, MODEL_SAVE_DIR)

恢复模型

sess = tf.Session()
## 这里是恢复graph
saver = tf.train.import_meta_graph(MODEL_SAVE_DIR + '/' + MODEL_NAME + '.meta')
## 这里是恢复各个权重参数
saver.restore(sess, tf.train.latest_checkpoint(MODEL_SAVE_DIR))


sess.run(tf.global_variables_initializer())
## 获取输入的tensor
input_x = tf.get_default_graph().get_tensor_by_name("input_x:0")
......

x_in = result1[0:SEQUENCE_LENGTH]
## 获取到encoder_op
encoder_op = tf.get_collection("encoder_op")[0]
## 给定数据,运行encoder_op 
s = sess.run(encoder_op, feed_dict={input_x: [x_in]})

具体的解释已经在代码中提及。这样我们就可以利用encoder_op对新数据进行编码了。

完整的恢复模型参看:tensorflow_restore.py

额外的话

参考资料:

A quick complete tutorial to save and restore Tensorflow models

在该参考资料中,你还可以看到多种保存和使用tensor的方式。另外除了保存模型以外,还有 tf.summary.FileWriter

train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

使用它可以让你通过tensorbord 查看训练和运行情况。

相关文章

网友评论

本文标题:TensorFlow 模型保存和恢复示例

本文链接:https://www.haomeiwen.com/subject/zzdjfxtx.html