美文网首页机器学习TF/coffedeep learningTensorFlow Lite
TensorFlow Lite学习笔记2:生成TFLite模型文

TensorFlow Lite学习笔记2:生成TFLite模型文

作者: 城市守望者 | 来源:发表于2017-11-21 14:52 被阅读6882次

    简介

    在桌面PC或是服务器上使用TensorFlow训练出来的模型文件,不能直接用在TFLite上运行,需要使用离线工具先转成.tflite文件。笔者发现官方文档中很多细节介绍的都不太明确,在使用过程中需要不断尝试。我把自己的尝试过的步骤分享出来,希望能帮助大家节省时间。

    具体说来,tflite文件的生成大致分为3步:

    1. 在算法训练的脚本中保存图模型文件(GraphDef)和变量文件(CheckPoint)。

    2. 利用freeze_graph工具生成frozen的graphdef文件。

    3. 利用toco工具,生成最终的tflite文件。

    图1. 生成tflite文件的整个流程示意图

    第1步:导出图模型文件和变量文件

    在你的算法的训练或推理任务的脚本中,利用tensorflow.train中的write_graph和saver API来导出GraphDef及Checkpoint文件。

    图2. TensorFlow中导出GraphDef文件和Checkpoint文件

    其中,tf.train.write_graph一行将导出模型的GraphDef文件,实际上保存了训练的神经网络的结构图信息。存储格式为protobuffer,所以文件名后缀为pb。

    图3. 导出的GraphDef文件

    tf.train.saver.save一行导出的是模型的变量文件,实际上保存了整个图中所有变量目前的取值。

    图4. 导出的checkpoint文件

    如图4所示,实际上产生了4个文件。在后续步骤中需要用到的是nsfw_model.ckpt.data-00000-of-00001这个文件,保存了当前神经网络各参数的取值。

    第2步:生成frozen的graphdef文件

    在此步骤中,使用Tensorflow源代码中自带的freeze_graph工具,生成一个frozen的GraphDef文件。

    bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/data/deep_learning/nsfw/model/nsfw-graph.pb --input_checkpoint=/data/deep_learning/nsfw/model/nsfw_model.ckpt --input_binary=true --output_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb --output_node_names=predictions

    这里有两个地方容易搞错。第一个地方,input_checkpoint参数实际上用到的文件应该是nsfw_model.ckpt.data-00000-of-00001,但是在指定文件名的时候只需要指定nsfw_model.ckpt即可。第二个地方,是output_node_names参数,此处指定的是神经网络图中的输出节点的名字,是在训练阶段的Python脚本中定义的。如下图所示,在定义网络结构时,输出节点的名称为"predictions"。则最终output_node_names需要指定为“predictions”。

    图5. output_node_names参数取值与网络模型定义时的名字要对应

    当然,也可以利用summarize_graph打印出模型的输入和输出节点,如:

    bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/data/deep_learning/nsfw/model/frozen_nsfw.pb

    图6. 输入节点名称为input 图7. 输出节点名称为predictions

    第3步:生成最终的tflite文件

    在此步骤中,使用Tensorflow源代码中自带的toco工具,生成一个可供TensorFlow Lite框架使用tflite文件。其中input_arrays和output_arrays的名称需要与定义网络类型时取的名称保持一致。

    bazel run --config=opt tensorflow/contrib/lite/toco:toco --input_file=/data/deep_learning/nsfw/model/frozen_nsfw.pb --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=/data/deep_learning/nsfw/model/nsfw.lite --inference_type=FLOAT --input_type=FLOAT --input_arrays=input --output_arrays=predictions --input_shapes=1,224,224,3

    生成的nsfw.lite文件即可用于TensorFlow Lite应用。

    相关文章

      网友评论

      • JuiYang:你好,我想请问下如何将多个.pd文件打成一个.lite文件
        bazel-bin/tensorflow/contrib/lite/toco/toco
        --input_file=/tmp/frozen_graph.pb
        --input_format=TENSORFLOW_GRAPHDEF
        --output_format=TFLITE
        --output_file=/tmp/mobilenet_v1_224.tflite
        --inference_type=FLOAT
        --input_arrays=input
        --output_arrays=MobilenetV1/Predictions/Reshape_1
        --output_arrays=MobilenetV1/Logists
        --input_shares=1,224,224,3

        这样的命令在android 上报错,说是模型输出只有一个。我想会不会是最后一个输出output_arrays覆盖了第一个输出output_arrays。如果是这种情况应该怎么解决,谢谢
      • a如果不曾相遇:麻烦问一下大家,有没有人是在windows 下安装bazel的,求教程:grin:
      • 李保林:用自己的图片重训练mobilenet模型,然后转换成tflite模型后,在demo中使用会报错,楼主是否有遇到过呢?

        Caused by: java.lang.IllegalArgumentException: Shape of output target [1, 49] does not match with the shape of the Tensor [1, 1001].
        at org.tensorflow.lite.Tensor.copyTo(Tensor.java:44)
        at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:139)
        at org.tensorflow.lite.Interpreter.run(Interpreter.java:107)
        at com.android.gallery3d.tensorflow.lite.ImageClassifierQuantizedMobileNet.runInference(ImageClassifierQuantizedMobileNet.java:95)
        at com.android.gallery3d.tensorflow.lite.ImageClassifier.classifyFrame(ImageClassifier.java:128)

        原来的tflite模型能识别1001种,而自己训练的49种类别,这样会报错? 1001不知是在代码哪里设下去的,还是跟模型有关系?
      • qumoy:Some of the operators in the model are not supported by the standard TensorFlow Lite runtime .If you have a custom implementation for them you can disable this error with --allow_custom_ops.Here is a list of operators for which you will need custom implementations:FLOOR,RandomUniform.
        大佬我在转换模型的时候遇到了这个问题,忽略之后在android端也遇到这个问题,想请教一下该如何解决?
      • Allence:你好 请问tensorflow lite 项目中怎么集成啊 有教程吗
      • 生如一人:请问下您的bazel安装方式是采用哪一个?我是ubuntu16.04的版本,采用APT安装之后出现了没有设置workspace等问题,如下:
        ERROR: Error evaluating WORKSPACE file
        ERROR: Skipping 'tensorflow/contrib/lite/toco:toco': error loading package 'external': Package 'external' contains errors
        WARNING: Target pattern parsing failed.
        ERROR: error loading package 'external': Package 'external' contains errors
        INFO: Elapsed time: 0.427s
        FAILED: Build did NOT complete successfully (0 packages loaded)
        请问这是什么问题,拜托了!
      • 程序员的天马行空:你好,我想问下,我生成的.ckpt文件4.2m,.pb文件1m都不到,为什么合在一起生成的.lite文件有42m那么大.
      • VoitYa:你好,我git clone tensorflow的项目以后, 没有找见bazel-bin/tensorflow/python/tools/freeze_graph这个目录和文件,freeze那一步进行不下去,这个目录是源码安装tensorflow才会生成吗,我用pip安装tensorflow怎么才能使用freeze这个工具呢?
      • 10befff22aa1:能否请问一下,在使用Tensorflow源代码中自带的toco工具时,您那边是否遇到过问题。
        我使用bazel build时,执行失败了。目前显示的错误是This package requires Visual Stuido 2015 Update 2 or higher。
        我想请问一下,这个执行bazel build的过程中,是否用到Visual Studio,还是说我的错误是在其他的地方。多谢了
        城市守望者:我用的是Ubuntu和Centos,没遇到问题
        城市守望者:@睿姝 这个倒是没试过,建议在windows上直接装个vmware或者virtual box,然后弄个ubuntu的虚拟机
        10befff22aa1:补充一下,我是在Windows系统下使用Cygwin运行的。这是否也会产生问题。
      • a2b0a5b174bf:你好,我再用“bazel build //tensorflow/contrib/lite/toco:toco”的时候,出现如下错误:
        ERROR: /Users/xiaoqiang/6TensorFlowlite/tensorflow-master/tensorflow/contrib/lite/BUILD:193:12: Label '//tensorflow/contrib/lite:downloads/absl/absl/types/optional.h' crosses boundary of subpackage 'tensorflow/contrib/lite/downloads/absl/absl/types' (perhaps you meant to put the colon here: '//tensorflow/contrib/lite/downloads/absl/absl/types:optional.h'?)

        请问你遇到了吗?顺便问一下,你的bazel是什么版本
      • 49ab9b070cf4:你好,我按照你这样的步骤生成了自己训练模型,但是通过tensforflow lite中替换了演示模型后,程序无法进行识别,请问还需要什么设置呢?
        49ab9b070cf4:@城市守望者 请问具体有哪些坑呢?
        49ab9b070cf4:@城市守望者 是的。用的TF Lite DEMO程序换了自己的模型,可以一起技术交流吗?我QQ:178789314。相互学习。
        城市守望者:你用的是TF Lite的demo程序?具体碰到的错误是什么,里面坑还是挺多的。。。

      本文标题:TensorFlow Lite学习笔记2:生成TFLite模型文

      本文链接:https://www.haomeiwen.com/subject/slwdvxtx.html