美文网首页随笔-生活工作点滴
tensorflow c++ 读入pb文件

tensorflow c++ 读入pb文件

作者: ryan_ren | 来源:发表于2019-07-09 21:34 被阅读8次

    最近因为项目原因,开始接触c++ tensorflow,感觉c++的比python复杂太多了。不过好在不需要使用c++来训练模型,而是python端训练好,把模型文件保存下来,c++端导入模型文件,然后读入图片进行预测。
    代码分为以下几个部分

    1. 创建session,同时设置gpu选项
    Session* session;
    SessionOptions options; 
    options.config.mutable_gpu_options()->set_visible_device_list(gpus); //设置使用的gpu
    options.config.mutable_gpu_options()->set_allow_growth(true); //设置GPU内存自动增长
    TF_CHECK_OK(NewSession(options, &session));//创建新会话Session
    
    1. 从pb文件中读取模型,将模型导入session
    GraphDef graphdef; //Graph Definition for current model
    TF_CHECK_OK(ReadBinaryProto(Env::Default(), model_path, &graphdef)); //从pb文件中读取图模型;
    TF_CHECK_OK(session->Create(graphdef)); //将模型导入会话Session中;
    
    1. 开始测试
    Mat image = imread(images_address + image_name);
    mat2Tensor(image, input);    //这个函数是将Mat转为Tensor
    std::vector<std::pair<std::string, tensorflow::Tensor> > in;    //模型输入数据
    std::vector<std::string> out;             //输出数据名称
    std::vector<tensorflow::Tensor> outputs;      //输出数据存放的数组
    in.push_back(pair<string, Tensor>(input_name, input));       //input_name为输入数据的名称,这个是将输入数据与名称一起放到数组中
    out.push_back(output_name); //将输出数据的名称放到数组中
    TF_CHECK_OK(session->Run(in, out, {}, &outputs));   //运行模型,得到输出
    
    1. 完整代码
    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/public/session.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/graph/default_device.h"
    
    #include <string>
    #include <stdlib.h>
    #include <iostream>
    #include <stdio.h>
    #include <ctime>
    
    #include <opencv2/highgui.hpp>
    #include <opencv2/core.hpp>
    #include <opencv2/opencv.hpp>
    #include <opencv2/imgproc/imgproc.hpp>
    
    //linux
    #include <sys/types.h>
    #include <dirent.h>
    #include <sys/stat.h>
    
    using namespace cv;
    using namespace tensorflow;
    using namespace std;
    
    const int IMAGE_SIZE = 512;
    const int CLASS = 36;
    
    void mat2Tensor(Mat &image, Tensor &t) {
        resize(image, image, Size(IMAGE_SIZE, IMAGE_SIZE));
    
        cvtColor(image, image, cv::COLOR_RGB2BGR);
        float *tensor_data_ptr = t.flat<float>().data();
        cv::Mat fake_mat(image.rows, image.cols, CV_32FC(image.channels()), tensor_data_ptr);
        image.convertTo(fake_mat, CV_32FC(image.channels()));
    }
    
    void tensor2Mat(Tensor &t, Mat &image) {
        int *p = t.flat<int>().data();
        image = Mat(IMAGE_SIZE, IMAGE_SIZE, CV_32SC1, p);
        image.convertTo(image, CV_8UC1);
    }
    void solve() {
        string model_path = "../modules/20190707.pb";
    
        string input_name = "inputs/X:0";
    
        string output_name = "preds:0";
    
        string images_address = "../Images/";
    
        string output_address = "../outputs/";
    
        string gpus = "0";
    
        /*--------------------------------创建session------------------------------*/
        Session* session;
        SessionOptions options;
        
        options.config.mutable_gpu_options()->set_visible_device_list(gpus);
    
        options.config.mutable_gpu_options()->set_allow_growth(true);
        
    
        TF_CHECK_OK(NewSession(options, &session));//创建新会话Session
        
    
        /*--------------------------------从pb文件中读取模型--------------------------------*/
        
        GraphDef graphdef; //Graph Definition for current model
        TF_CHECK_OK(ReadBinaryProto(Env::Default(), model_path, &graphdef)); //从pb文件中读取图模型;
    
        TF_CHECK_OK(session->Create(graphdef)); //将模型导入会话Session中;
    
        std::cout << "<----Successfully created session and load graph.------->" << std::endl;
        Tensor input(DT_FLOAT, TensorShape({ 1, IMAGE_SIZE, IMAGE_SIZE, 3 }));
    
        /*--------------------------------读取目录下的图片--------------------------------*/
    
        vector<string> images;  //图片名称
    
        for(int i = 4800; i < 5550; i++) {
            char tmp[20];
            sprintf(tmp, "img_%06d.png",i*4);
            // cout << tmp << endl;
            images.push_back(tmp);
        }
    
        cout << "load data success" << endl;
    
        /*--------------------------------开始测试--------------------------------*/
        cout << "start run" << endl;
        double total_time = 0;
        Mat lut(1, 256, CV_8UC3, lutData);
        for(string image_name : images) {
            Mat image = imread(images_address + image_name);
            mat2Tensor(image, input);
            std::vector<std::pair<std::string, tensorflow::Tensor> > in;
            std::vector<std::string> out;
            std::vector<tensorflow::Tensor> outputs;
            in.push_back(pair<string, Tensor>(input_name, input));
            out.push_back(output_name);     
            TF_CHECK_OK(session->Run(in, out, {}, &outputs));
            Mat res (IMAGE_SIZE, IMAGE_SIZE, CV_8UC1);
            tensor2Mat(outputs[0], res);
            cv::imwrite(output_address + image_name, res);
        }
    }
    int main()
    {
        solve();
        return 0;
    }
    
    

    相关文章

      网友评论

        本文标题:tensorflow c++ 读入pb文件

        本文链接:https://www.haomeiwen.com/subject/keeyhctx.html