前言
在之前一篇文章里:使用CNN+ Auto-Encoder 实现无监督Sentence Embedding (代码基于Tensorflow),训练完成后,encode的参数也就被训练好了,这个时候我们利用这些参数对数据进行编码处理,从而得到一个向量。
保存模型
如果回忆下,上次的模型基本是这样的:
Input(段落) -> encoder -> encoder -> decoder -> decoder -> lost function (consine夹角)
我需要用到的是第二个encoder,在Tensorflow里,所有的都是Tensor,因此给定输入,就可以通过tensor给出输出。训练的过程中,涉及到第二个encoder的代码如下:
....
flattened = tf.reshape(conv_out, [-1, 51 * 128]) if USE_CNN else tf.reshape(input_x,
[-1, SEQUENCE_LENGTH * VOCAB_SIZE])
encoder_op = encoder(flattened)
....
我们真个训练过程其实是在tunning encoder的参数。现在我需要把encoder_op保留下来,供下次使用,这可以通过add_collection方法
tf.add_to_collection('encoder_op', encoder_op)
在sess.run(tf.global_variables_initializer())
之后,我们获取Saver对象:
saver = tf.train.Saver()
然后在迭代的过程中,比如每迭代五次就保存一次模型:
if i %5 = 0:
saver.save(sess, MODEL_SAVE_DIR)
恢复模型
sess = tf.Session()
## 这里是恢复graph
saver = tf.train.import_meta_graph(MODEL_SAVE_DIR + '/' + MODEL_NAME + '.meta')
## 这里是恢复各个权重参数
saver.restore(sess, tf.train.latest_checkpoint(MODEL_SAVE_DIR))
sess.run(tf.global_variables_initializer())
## 获取输入的tensor
input_x = tf.get_default_graph().get_tensor_by_name("input_x:0")
......
x_in = result1[0:SEQUENCE_LENGTH]
## 获取到encoder_op
encoder_op = tf.get_collection("encoder_op")[0]
## 给定数据,运行encoder_op
s = sess.run(encoder_op, feed_dict={input_x: [x_in]})
具体的解释已经在代码中提及。这样我们就可以利用encoder_op对新数据进行编码了。
完整的恢复模型参看:tensorflow_restore.py
额外的话
参考资料:
A quick complete tutorial to save and restore Tensorflow models
在该参考资料中,你还可以看到多种保存和使用tensor的方式。另外除了保存模型以外,还有 tf.summary.FileWriter
train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
使用它可以让你通过tensorbord 查看训练和运行情况。
网友评论
欢迎订阅《数据淘金》https://toutiao.io/subject/120437