最近因为项目需要，开始接触 C++ 版的 TensorFlow，感觉 C++ 接口比 Python 复杂得多。好在不需要用 C++ 来训练模型：在 Python 端训练好模型并保存为模型文件（.pb），C++ 端只需导入该模型文件，然后读入图片进行预测。
代码分为以下几个部分
- 创建 Session，同时设置 GPU 选项
// Excerpt: create the Session and configure its GPU options.
Session* session;
SessionOptions options;
options.config.mutable_gpu_options()->set_visible_device_list(gpus); // restrict TensorFlow to the GPU ids listed in `gpus`
options.config.mutable_gpu_options()->set_allow_growth(true); // let GPU memory usage grow on demand
TF_CHECK_OK(NewSession(options, &session));// create a new Session
- 从 .pb 文件中读取模型，并将模型导入 Session
// Excerpt: read the model from the .pb file and load it into the Session.
GraphDef graphdef; // graph definition for the current model
TF_CHECK_OK(ReadBinaryProto(Env::Default(), model_path, &graphdef)); // read the serialized graph from the .pb file
TF_CHECK_OK(session->Create(graphdef)); // load the graph into the Session
- 开始测试（读入图片并运行模型预测）
// Excerpt: feed one image through the model and fetch the prediction.
Mat image = imread(images_address + image_name);
mat2Tensor(image, input); // converts the cv::Mat into the input Tensor
std::vector<std::pair<std::string, tensorflow::Tensor> > in; // model inputs as (tensor name, value) pairs
std::vector<std::string> out; // names of the tensors to fetch
std::vector<tensorflow::Tensor> outputs; // receives the fetched tensors
in.push_back(pair<string, Tensor>(input_name, input)); // pair the input data with its tensor name in the graph
out.push_back(output_name); // request the output tensor by name
TF_CHECK_OK(session->Run(in, out, {}, &outputs)); // run the model and collect the outputs
- 完整代码
//linux
#include <sys/types.h>
#include <dirent.h>
#include <sys/stat.h>
#include <stdio.h>
#include <stdlib.h>

#include <cstdint>
#include <ctime>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"

#include <opencv2/highgui.hpp>
#include <opencv2/core.hpp>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/imgproc.hpp>

using namespace cv;
using namespace tensorflow;
using namespace std;
// Side length in pixels of the square network input: images are resized to
// IMAGE_SIZE x IMAGE_SIZE before inference, and the prediction map is read back
// at the same size.
constexpr int IMAGE_SIZE = 512;
// Presumably the number of classes the model predicts; not referenced elsewhere
// in this listing.
constexpr int CLASS = 36;
/// Resizes an OpenCV image to IMAGE_SIZE x IMAGE_SIZE, swaps its channel order and
/// copies the pixels, converted to float, into the pre-allocated tensor `t`.
///
/// @param image  Source image. It is modified in place: afterwards it holds the
///               resized, channel-swapped pixels.
/// @param t      Destination tensor, expected to be DT_FLOAT. Its element count must
///               equal IMAGE_SIZE * IMAGE_SIZE * channels
///               (e.g. shape {1, IMAGE_SIZE, IMAGE_SIZE, 3}).
/// @throws std::invalid_argument if `image` is empty or `t` has the wrong element count.
///
/// Pixel values are converted as-is (no scaling or normalisation is applied).
void mat2Tensor(Mat &image, Tensor &t) {
    // imread() returns an empty Mat when it cannot read a file; fail with a clear
    // message instead of an opaque OpenCV assertion inside resize().
    if (image.empty()) {
        throw std::invalid_argument("mat2Tensor: input image is empty");
    }
    resize(image, image, Size(IMAGE_SIZE, IMAGE_SIZE));
    // imread() yields BGR pixel order; swap to RGB. COLOR_BGR2RGB and COLOR_RGB2BGR
    // perform the same swap, this spelling just states the intent.
    // NOTE(review): assumes the model was trained on RGB input - confirm against the
    // Python training pipeline.
    cvtColor(image, image, cv::COLOR_BGR2RGB);
    // fake_mat below is a non-owning view of the tensor's buffer sized from the
    // image. If the tensor held fewer elements, convertTo would write past its end;
    // if it held more, part of it would keep stale data. Check before writing.
    const size_t needed = image.total() * static_cast<size_t>(image.channels());
    if (static_cast<size_t>(t.NumElements()) != needed) {
        throw std::invalid_argument("mat2Tensor: tensor element count does not match the resized image");
    }
    float *tensor_data_ptr = t.flat<float>().data();
    cv::Mat fake_mat(image.rows, image.cols, CV_32FC(image.channels()), tensor_data_ptr);
    // fake_mat already has the destination size and type, so convertTo writes
    // directly into the tensor's memory rather than allocating a new buffer.
    image.convertTo(fake_mat, CV_32FC(image.channels()));
}
/// Converts a single-channel int32 prediction tensor into an 8-bit OpenCV image.
///
/// @param t      Model output. Read through t.flat<int>(), so it must be int32; it
///               must contain exactly IMAGE_SIZE * IMAGE_SIZE elements.
///               NOTE(review): presumably a per-pixel class-index map - confirm.
/// @param image  Receives an IMAGE_SIZE x IMAGE_SIZE CV_8UC1 image. If it already has
///               that size and type, its buffer is reused.
/// @throws std::invalid_argument if `t` does not hold IMAGE_SIZE * IMAGE_SIZE elements.
///
/// Values outside [0, 255] are saturated by convertTo.
void tensor2Mat(Tensor &t, Mat &image) {
    const std::int64_t expected = static_cast<std::int64_t>(IMAGE_SIZE) * IMAGE_SIZE;
    // The Mat header below trusts IMAGE_SIZE for its dimensions; without this check a
    // smaller tensor would be read out of bounds.
    if (t.NumElements() != expected) {
        throw std::invalid_argument("tensor2Mat: output tensor does not hold IMAGE_SIZE x IMAGE_SIZE elements");
    }
    int *p = t.flat<int>().data();
    // Non-owning view of the tensor buffer; convertTo writes the 8-bit result into
    // `image`'s own storage, so `image` never aliases the tensor's memory.
    const Mat wrapped(IMAGE_SIZE, IMAGE_SIZE, CV_32SC1, p);
    wrapped.convertTo(image, CV_8UC1);
}
void solve() {
string model_path = "../modules/20190707.pb";
string input_name = "inputs/X:0";
string output_name = "preds:0";
string images_address = "../Images/";
string output_address = "../outputs/";
string gpus = "0";
/*--------------------------------创建session------------------------------*/
Session* session;
SessionOptions options;
options.config.mutable_gpu_options()->set_visible_device_list(gpus);
options.config.mutable_gpu_options()->set_allow_growth(true);
TF_CHECK_OK(NewSession(options, &session));//创建新会话Session
/*--------------------------------从pb文件中读取模型--------------------------------*/
GraphDef graphdef; //Graph Definition for current model
TF_CHECK_OK(ReadBinaryProto(Env::Default(), model_path, &graphdef)); //从pb文件中读取图模型;
TF_CHECK_OK(session->Create(graphdef)); //将模型导入会话Session中;
std::cout << "<----Successfully created session and load graph.------->" << std::endl;
Tensor input(DT_FLOAT, TensorShape({ 1, IMAGE_SIZE, IMAGE_SIZE, 3 }));
/*--------------------------------读取目录下的图片--------------------------------*/
vector<string> images; //图片名称
for(int i = 4800; i < 5550; i++) {
char tmp[20];
sprintf(tmp, "img_%06d.png",i*4);
// cout << tmp << endl;
images.push_back(tmp);
}
cout << "load data success" << endl;
/*--------------------------------开始测试--------------------------------*/
cout << "start run" << endl;
double total_time = 0;
Mat lut(1, 256, CV_8UC3, lutData);
for(string image_name : images) {
Mat image = imread(images_address + image_name);
mat2Tensor(image, input);
std::vector<std::pair<std::string, tensorflow::Tensor> > in;
std::vector<std::string> out;
std::vector<tensorflow::Tensor> outputs;
in.push_back(pair<string, Tensor>(input_name, input));
out.push_back(output_name);
TF_CHECK_OK(session->Run(in, out, {}, &outputs));
Mat res (IMAGE_SIZE, IMAGE_SIZE, CV_8UC1);
tensor2Mat(outputs[0], res);
cv::imwrite(output_address + image_name, res);
}
}
/// Program entry point: runs the batch-inference pipeline.
/// TF_CHECK_OK failures inside solve() abort the process before this function
/// returns, so reaching the return statement means the run completed.
int main()
{
    solve();
    return EXIT_SUCCESS;
}
网友评论