美文网首页
C++ opencv-3.4.1 调用tensorflow训练好

C++ OpenCV 3.4.1 调用 TensorFlow 训练好的模型

作者: yanghedada | 来源:发表于2019-03-11 20:29 被阅读0次

本文使用 TensorFlow 目标检测 model zoo 中的 ssd_mobilenet_v1_coco 模型,在 C++ 中通过 OpenCV 3.4.1 的 dnn 模块进行目标检测测试,过程中参考了一些博客(见文末)。
首先需要从 TensorFlow detection_model_zoo 下载 ssd_mobilenet_v1_coco。下载的文件中包含 frozen_inference_graph.pb 和一个 graph.pbtxt 文件,但 OpenCV 的 C++ 接口无法直接使用这个 graph.pbtxt,所以需要改用 ssd_mobilenet_v1_coco.pbtxt 作为网络配置文件。
本文没有训练任何模型,而是直接使用预训练的 ssd_mobilenet_v1_coco 检测 COCO 数据集的 80 个类别。

#include <algorithm>
#include <iostream>
#include <sstream>

#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>

using namespace std;
using namespace cv;

// Network input size used for ssd_mobilenet_v1_coco (300x300).
const size_t inWidth = 300;
const size_t inHeight = 300;
// Width/height ratio of the network input; used to centre-crop the frame.
const float WHRatio = inWidth / static_cast<float>(inHeight);

// Class names indexed directly by the class id the detector reports.
// The TensorFlow Object Detection API COCO label map uses ids 1..90 with ten
// unused ids (12, 26, 29, 30, 45, 66, 68, 69, 71, 83); those slots hold "N/A"
// so that classNames[classId] is the correct name and every id 0..90 is in
// range. Index 0 is the background class.
// Rebuild this table when using a model trained on a different label map.
const char* const classNames[] = {
    "background",                                                         // 0
    "person", "bicycle", "car", "motorcycle", "airplane",                 // 1-5
    "bus", "train", "truck", "boat", "traffic light",                     // 6-10
    "fire hydrant", "N/A", "stop sign", "parking meter", "bench",         // 11-15
    "bird", "cat", "dog", "horse", "sheep",                               // 16-20
    "cow", "elephant", "bear", "zebra", "giraffe",                        // 21-25
    "N/A", "backpack", "umbrella", "N/A", "N/A",                          // 26-30
    "handbag", "tie", "suitcase", "frisbee", "skis",                      // 31-35
    "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", // 36-40
    "skateboard", "surfboard", "tennis racket", "bottle", "N/A",          // 41-45
    "wine glass", "cup", "fork", "knife", "spoon",                        // 46-50
    "bowl", "banana", "apple", "sandwich", "orange",                      // 51-55
    "broccoli", "carrot", "hot dog", "pizza", "donut",                    // 56-60
    "cake", "chair", "couch", "potted plant", "bed",                      // 61-65
    "N/A", "dining table", "N/A", "N/A", "toilet",                        // 66-70
    "N/A", "tv", "laptop", "mouse", "remote",                             // 71-75
    "keyboard", "cell phone", "microwave", "oven", "toaster",             // 76-80
    "sink", "refrigerator", "N/A", "book", "clock",                       // 81-85
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush"          // 86-90
};

int main() {

    Mat frame = cv::imread("1.jpg");
    Size frame_size = frame.size();
    imshow("image0", frame);

    String weights = "frozen_inference_graph.pb";
    String prototxt = "ssd_mobilenet_v1_coco.pbtxt";
    dnn::Net net = cv::dnn::readNetFromTensorflow(weights, prototxt);

    Size cropSize;
    if (frame_size.width / (float)frame_size.height > WHRatio)
    {
        cropSize = Size(static_cast<int>(frame_size.height * WHRatio),
            frame_size.height);
    }
    else
    {
        cropSize = Size(frame_size.width,
            static_cast<int>(frame_size.width / WHRatio));
    }

    Rect crop(Point((frame_size.width - cropSize.width) / 2,
        (frame_size.height - cropSize.height) / 2),
        cropSize);


    cv::Mat blob = cv::dnn::blobFromImage(frame, 1. / 255, Size(300, 300));
    cout << "blob size: " << blob.size << endl;

    net.setInput(blob);
    Mat output = net.forward();
    cout << "output size: " << output.size << endl;

    Mat detectionMat(output.size[2], output.size[3], CV_32F, output.ptr<float>());

    frame = frame(crop);
    float confidenceThreshold = 0.20;
    for (int i = 0; i < detectionMat.rows; i++)
    {
        float confidence = detectionMat.at<float>(i, 2);

        if (confidence > confidenceThreshold)
        {
            size_t objectClass = (size_t)(detectionMat.at<float>(i, 1));

            int xLeftBottom = static_cast<int>(detectionMat.at<float>(i, 3) * frame.cols);
            int yLeftBottom = static_cast<int>(detectionMat.at<float>(i, 4) * frame.rows);
            int xRightTop = static_cast<int>(detectionMat.at<float>(i, 5) * frame.cols);
            int yRightTop = static_cast<int>(detectionMat.at<float>(i, 6) * frame.rows);

            ostringstream ss;
            ss << confidence;
            String conf(ss.str());

            Rect object((int)xLeftBottom, (int)yLeftBottom,
                (int)(xRightTop - xLeftBottom),
                (int)(yRightTop - yLeftBottom));

            rectangle(frame, object, Scalar(0, 255, 0), 2);
            String label = String(classNames[objectClass]) + ": " + conf;
            int baseLine = 0;
            Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
            rectangle(frame, Rect(Point(xLeftBottom, yLeftBottom - labelSize.height),
                Size(labelSize.width, labelSize.height + baseLine)),
                Scalar(0, 255, 0), CV_FILLED);
            putText(frame, label, Point(xLeftBottom, yLeftBottom),
                FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0));
        }
    }
    imshow("image", frame);
    waitKey(0);
    return 0;
}

结果

参考:

基于opencv dnn模块 的caffe模型的调用
OpenCV调用TensorFlow预训练模型
OpenCV的dnn模块调用TensorFlow训练的MobileNet模型

相关文章

网友评论

      本文标题:C++ opencv-3.4.1 调用tensorflow训练好

      本文链接:https://www.haomeiwen.com/subject/jrripqtx.html