简介
由于生产环境使用windows、C++,而tensorflow模型训练使用python更为方便,因此存在需求:在windows环境使用tensorflow的c++接口载入训练好的tensorflow模型,并进行测试。类似的文档比较缺乏,并且由于tf本身一直在完善,相比现有的博客各个步骤都有进一步的简化,这里针对1.2.0版本梳理对应的最简单的一种流程:
- 利用tensorflow的python API定义、训练自己的模型
- 利用tensorflow的python API保存模型,并进一步将模型中的变量都转化为常量,通过这样“freeze graph”使得模型导出为一个文件,便于c++调用
- 编译tensorflow的源码来使用tensorflow的c++接口
- 在tensorflow的tutorrials Image Recognition 的基础上修改代码,利用模型进行测试。
利用tf的python API训练模型
这部分属于tensorflow的基础,官方文档getting started有相当详细的介绍和描述,在此不做赘述。值得注意的是tf的命名方式,在python代码中的变量名和在tf的graph中的变量名是两个概念,因此至少针对输入输出要定义tf的graph中的变量名,定义变量名的语法类似loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
。此外也可以利用tf.name_scope
来规划命名。
导出tf模型并freeze graph
这部分有官方工具的代码freeze_graph.py,对应的博客也很多。这里我推荐博客TensorFlow: How to freeze a model and serve it with a python API。
freeze graph就是把原本的图中的变量(卷积核、偏置)等都使用训练好的模型中的值来代替,变成常量。frozen graph的意义在于(freeze_graph.py的注释)
It's useful to do this when we need to load a single file in C++, especially in environments like mobile or embedded where we may not have access to the RestoreTensor ops and file loading calls that they rely on.
推荐的主要原因在于博客中使用方法saver = tf.train.Saver();last_chkp = saver.save(sess, 'results/graph.chkp')
是最为简单的保存模型的方法,同时博客提供了freeze graph的代码,核心采用graph_util.convert_variables_to_constants
方法来进行freeze graph,使得不需要使用官方工具freeze_graph.py。对应freeze_graph的代码引用如下(其中注意到write使用参数‘wb'写为二进制):
import os, argparse
import tensorflow as tf
from tensorflow.python.framework import graph_util
dir = os.path.dirname(os.path.realpath(__file__))
def freeze_graph(model_folder):
# We retrieve our checkpoint fullpath
checkpoint = tf.train.get_checkpoint_state(model_folder)
input_checkpoint = checkpoint.model_checkpoint_path
# We precise the file fullname of our freezed graph
absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
output_graph = absolute_model_folder + "/frozen_model.pb"
# Before exporting our graph, we need to precise what is our output node
# This is how TF decides what part of the Graph he has to keep and what part it can dump
# NOTE: this variable is plural, because you can have multiple output nodes
output_node_names = "Accuracy/predictions"
# We clear devices to allow TensorFlow to control on which device it will load operations
clear_devices = True
# We import the meta graph and retrieve a Saver
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
# We retrieve the protobuf graph definition
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
# We start a session and restore the graph weights
with tf.Session() as sess:
saver.restore(sess, input_checkpoint)
# We use a built-in TF helper to export variables to constants
output_graph_def = graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
input_graph_def, # The graph_def is used to retrieve the nodes
output_node_names.split(",") # The output node names are used to select the usefull nodes
)
# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--model_folder", type=str, help="Model folder to export")
args = parser.parse_args()
freeze_graph(args.model_folder)
编译源码来使用tf的c++ API
编译源码的方式官方有文档Installing TensorFlow from Sources,其中有段:
We don't officially support building TensorFlow on Windows; however, you may try to build TensorFlow on Windows if you don't mind using the highly experimental Bazel on Windows or TensorFlow CMake build.
在两种方案中,我选择采用cmake,理由是相对来说环境配置更为容易,但可能使用google自己的bazel相对支持度更高。
参考官方readme一步一步来,值得注意的有两点,一个是git clone的时候推荐git对应的稳定版本的分支(直接master可能会有编译错误和未知bug);另一个是要用命令行进行编译,直接采用vs2015 IDE进行编译会出错C1060,原因应该是默认的编译器调用的不是native 64位的toolset,如何设置使得能够使用IDE直接编译调试的方法还没有找到。
相比于官方的项目tf_tutorials_example_trainer.vcxproj,更有参考意义的项目是tf_label_image_example.vcxproj,对应的详尽官方教程Image Recognition,这个教程使用inception模型来进行识别,对应运行时可能需要修改图片和文件的路径才能正确输出结果。
修改代码实现自己的模型
教程源码提供了模型读取,图片读取,Label读取等核心步骤,修改对应代码进行编译能够很容易上手完成任务,下面贴一下保存图片的代码,总体是读取图片的逆向过程:
// Given an output tensor with 4d, reduce dim and output jpg image
Status SaveTensorToImageFile(const string& file_name, const Tensor* out_tensor) {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
auto output_image_data = tensorflow::ops::Reshape(root, *out_tensor, { 256, 256, 3 });
auto output_image_data_cast = tensorflow::ops::Cast(root, output_image_data, tensorflow::DT_UINT8);
auto output_image = tensorflow::ops::EncodeJpeg(root, output_image_data_cast);
auto output_op = tensorflow::ops::WriteFile(root.WithOpName("output/image"), file_name/*"D:/tf_face/trained_model_fast/output.jpg"*/, output_image);
string output_name = "output/image";
// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensor.
tensorflow::GraphDef graph;
TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph));
Status writeResult = session->Run({}, {}, { output_name }, {});
return writeResult;
}
代码中图片的尺寸可以自行定义,其中要注意的是c++中session->Run函数传入的参数无论是ops或是Tensor都是要使用tf定义的名字root.WithOpName("output/image")
而不是c++代码中定义的局部变量output_op
,以上在tf的CPU版本上流程走通。
参考链接
Tensorflow C++ API调用预训练模型和生产环境编译 (unix )
TensorFlow: How to freeze a model and serve it with a python API
TensorFlow CMake build
Tensorflow Tutorial Image Recognition
网友评论