参考链接:https://www.jianshu.com/p/091415b114e2
https://cloud.tencent.com/developer/ask/188650
由于arm nn官方提供的mnist-tf例程中提供的模型类型是prototxt或者pb文件,所以这里需要把tensorflow保存的ckpt文件转换成pb文件
tensorflow训练生成的ckpt文件包含4个,分别是
1. checkpoint文件,记录了最新的检查点文件
2. model.data文件,是saver.save(sess)保存的结果,记录了所有变量的值
3. model.index文件,暂不明确,待查。恢复模型不必须用到
4. model.meta文件,保存了计算图的结构,没有变量的值
转换方法
-
使用freeze_graph(见第一个参考链接,经过测试发现对于很小的模型lenet5可以成功,但是对于较大的模型,比如这里用到的一个400MB左右的网络,经过测试,会把16GB的内存消耗干净,转换失败=_=)
-
使用convert_variables_to_constants
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
sess = tf.Session()
saver = tf.train.import_meta_graph("meta文件目录")
saver.restore(sess, tf.train.latest_checkpoint("checkpoint文件所在目录"))
graph = tf.get_default_graph()
output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['输出tensor名字'])
with tf.gfile.FastGFile('pb文件保存目录', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
网友评论