美文网首页
Tensorflow模型文件固化

Tensorflow模型文件固化

作者: AIPlayer | 来源:发表于2019-08-09 19:45 被阅读0次

    当我们拿到别人训练后保存的模型文件后,如果需要通过C++接口部署模型的话,一般情况下都需要将模型固化并保存为pb格式。友好的Tensorflow提供了相关的固化命令脚本,下面以 .meta 格式的固化为例说明使用的方式:

    一、文件内容

    用 tf.train.Saver.save() 方式保存下来的checkpoint会产生四个文件:

    • checkpoint
      记录了部分已存储和最近存储的模型:
    model_checkpoint_path: "model.ckpt-1"
    all_model_checkpoint_paths: "model.ckpt-1"
    ...
    
    • model.ckpt.data-00000-of-00001
      保存了模型的所有变量的值,TensorBundle集合。

    • model.ckpt.index
      string-string的映射表,映射表的key值为tensor名,value为serialized BundleEntryProto,每个BundleEntryProto表述了tensor的metadata。

    • model.ckpt.meta
      保存了graph结构,包括 GraphDef,SaverDef等,当存在meta file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值。

    二、固化命令

    python tensorflow/python/tools/freeze_graph.py \
                --input_meta_graph=model.ckpt.meta \
                --input_checkpoint=model.ckpt \
                --output_graph=frozen_graph_meta.pb \
                --output_node_name=embeddings \
                --input_binary=True
    
    • 问题一、可能遇到的错误
    UnicodeDecoderError: 'utf-8' codec can't decode byte 0xd8 in position 1: invalid continuation byte
    

    解决:传入参数 --input_binary=True

    • 问题二、当我们手上只有别人训练好的模型文件时,如何确定输入参数中的--output_node_name呢?

    最直观的方式是使用TensorBoard查看图结构,步骤如下:
    (1)从.meta文件生成Tensorboard所需的log,可以通过以下代码生成,运行脚本后在log目录下生成events.out.tfevents.xxx文件。

    import tensorflow as tf
    import os
    def write_graph_log(meta_file, log_dir):
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)        
        g = tf.Graph()
        with g.as_default() as g:
            tf.train.import_meta_graph(meta_file)       
        with tf.Session(graph=g) as sess:
            tf.summary.FileWriter(logdir=log_dir, graph=g)        
    if __name__ == '__main__':
        write_graph_log('model.ckpt.meta', './log/')
    

    (2) 在Windows下通过cmd命令行启动Tensorboard

    cd model_dir               # 进入模型文件所在的目录
    tensorboard --logdir=log   # 启动tensorboard,指定log目录
    

    (3)浏览器打开tensorflow显示的网址(一般为http://127.0.0.1:6006),通过可视化的图结构可以清楚地看到输入和输出节点的名字。

    相关文章

      网友评论

          本文标题:Tensorflow模型文件固化

          本文链接:https://www.haomeiwen.com/subject/ylifjctx.html