美文网首页
opencv中SVM训练mnist手写体

opencv中SVM训练mnist手写体

作者: 一路向后 | 来源:发表于2021-12-19 17:05 被阅读0次

    1.源码实现

    #include <iostream>
    #include <string>
    #include <fstream>
    #include <opencv2/opencv.hpp>
    #include <opencv2/ml/ml.hpp>
    #include <opencv2/highgui/highgui.hpp>
    
    using namespace std;
    using namespace cv;
    
    // Byte-order swap: returns i with its four low-order bytes reversed
    // (byte 0 <-> byte 3, byte 1 <-> byte 2). The MNIST readers in this file
    // apply it to every 32-bit header field they read.
    unsigned int reverseInt(unsigned int i)
    {
        const unsigned int lowest  = i & 0xffu;
        const unsigned int second  = (i >> 8) & 0xffu;
        const unsigned int third   = (i >> 16) & 0xffu;
        const unsigned int highest = (i >> 24) & 0xffu;
    
        // The four pieces occupy disjoint bit ranges, so OR-ing them is the
        // same as adding them.
        return (lowest << 24) | (second << 16) | (third << 8) | highest;
    }
    
    // Reads an MNIST image file (IDX3 layout: four 32-bit header fields followed
    // by one unsigned byte per pixel) and returns the pixels as a CV_32FC1
    // matrix with one image per row and n_rows*n_cols columns.
    //
    // fileName: path of the image file.
    //
    // Returns an empty Mat when the file cannot be opened, the header cannot be
    // read, or the magic number is not the image-file value. If the pixel data
    // ends early, the pixels that could not be read stay zero and a warning is
    // written to stderr.
    Mat read_mnist_image(const string fileName)
    {
        // Magic number of an IDX3 unsigned-byte image file (0x00000803).
        const unsigned int kImageMagic = 2051;
    
        unsigned int magic_number = 0;
        unsigned int number_of_images = 0;
        unsigned int n_rows = 0;
        unsigned int n_cols = 0;
        Mat DataMat;
    
        ifstream file(fileName, ios::binary);
        if(!file.is_open())
        {
            cerr << "无法打开图像集: " << fileName << endl;
            return DataMat;
        }
    
        cout << "成功打开图像集..." << endl;
    
        file.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));           // magic number (file format)
        file.read(reinterpret_cast<char *>(&number_of_images), sizeof(number_of_images));   // number of images
        file.read(reinterpret_cast<char *>(&n_rows), sizeof(n_rows));                       // rows per image
        file.read(reinterpret_cast<char *>(&n_cols), sizeof(n_cols));                       // columns per image
    
        if(!file)
        {
            cerr << "图像集文件头读取失败: " << fileName << endl;
            return DataMat;
        }
    
        // Header fields are stored most-significant byte first; reverseInt
        // swaps them (this assumes the program runs on a little-endian host).
        magic_number = reverseInt(magic_number);
        number_of_images = reverseInt(number_of_images);
        n_rows = reverseInt(n_rows);
        n_cols = reverseInt(n_cols);
    
        cout << "幻数(文件格式): " << magic_number << endl;
        cout << "图片总数: " << number_of_images << endl;
        cout << "每个图像的行数: " << n_rows << endl;
        cout << "每个图像的列数: " << n_cols << endl;
    
        // Reject anything that is not an image file before trusting the sizes:
        // a garbage header would otherwise drive a huge allocation.
        if(magic_number != kImageMagic)
        {
            cerr << "图像集幻数错误: 期望 " << kImageMagic << ", 实际 " << magic_number << endl;
            return DataMat;
        }
    
        cout << "开始读取Image数据..." << endl;
    
        const int image_count = static_cast<int>(number_of_images);
        const int pixels_per_image = static_cast<int>(n_rows * n_cols);
    
        DataMat = Mat::zeros(image_count, pixels_per_image, CV_32FC1);
    
        if(image_count > 0 && pixels_per_image > 0)
        {
            // One image is read per stream call instead of one byte per call.
            Mat row_buffer(1, pixels_per_image, CV_8UC1);
            unsigned char *src = row_buffer.ptr<unsigned char>(0);
    
            for(int i=0; i<image_count; i++)
            {
                file.read(reinterpret_cast<char *>(src), pixels_per_image);
    
                // Only the bytes actually delivered are valid; on a short read
                // the rest of the buffer still holds the previous image.
                const streamsize got = file.gcount();
                float *dst = DataMat.ptr<float>(i);
    
                for(streamsize j=0; j<got; j++)
                {
                    dst[j] = float(src[j]);
                }
    
                if(got < pixels_per_image)
                {
                    cerr << "图像数据不完整: 期望 " << image_count << " 张, 在第 " << i << " 张处提前结束" << endl;
                    break;
                }
            }
        }
    
        cout << "读取Image数据完毕..." << endl;
    
        return DataMat;
    }
    
    // Reads an MNIST label file (IDX1 layout: two 32-bit header fields followed
    // by one unsigned byte per label) and returns the labels as a CV_32SC1
    // column matrix with one label per row.
    //
    // fileName: path of the label file.
    //
    // Returns an empty Mat when the file cannot be opened, the header cannot be
    // read, or the magic number is not the label-file value. If the label data
    // ends early, the labels that could not be read stay zero and a warning is
    // written to stderr.
    Mat read_mnist_label(const string fileName)
    {
        // Magic number of an IDX1 unsigned-byte label file (0x00000801).
        const unsigned int kLabelMagic = 2049;
    
        unsigned int magic_number = 0;
        unsigned int number_of_items = 0;
        Mat LabelMat;
    
        ifstream file(fileName, ios::binary);
        if(!file.is_open())
        {
            cerr << "无法打开标签集: " << fileName << endl;
            return LabelMat;
        }
    
        cout << "成功打开标签集..." << endl;
    
        file.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));         // magic number (file format)
        file.read(reinterpret_cast<char *>(&number_of_items), sizeof(number_of_items));   // number of labels
    
        if(!file)
        {
            cerr << "标签集文件头读取失败: " << fileName << endl;
            return LabelMat;
        }
    
        // Header fields are stored most-significant byte first; reverseInt
        // swaps them (this assumes the program runs on a little-endian host).
        magic_number = reverseInt(magic_number);
        number_of_items = reverseInt(number_of_items);
    
        cout << "幻数(文件格式): " << magic_number << endl;
        cout << "标签总数: " << number_of_items << endl;
    
        // Reject anything that is not a label file before trusting the count.
        if(magic_number != kLabelMagic)
        {
            cerr << "标签集幻数错误: 期望 " << kLabelMagic << ", 实际 " << magic_number << endl;
            return LabelMat;
        }
    
        cout << "开始读取Label数据..." << endl;
    
        const int label_count = static_cast<int>(number_of_items);
    
        LabelMat = Mat::zeros(label_count, 1, CV_32SC1);
    
        for(int i=0; i<label_count; i++)
        {
            unsigned char temp = 0;
    
            if(!file.read(reinterpret_cast<char *>(&temp), sizeof(temp)))
            {
                cerr << "标签数据不完整: 期望 " << label_count << " 个, 实际读取 " << i << " 个" << endl;
                break;
            }
    
            // CV_32SC1 elements are signed 32-bit ints, so access them as int.
            LabelMat.at<int>(i, 0) = static_cast<int>(temp);
        }
    
        cout << "读取Label数据完毕..." << endl;
    
        return LabelMat;
    }
    
    // Copies rows start, start+step, start+2*step, ... of the float matrix a
    // into a freshly allocated CV_32FC1 matrix b.
    //
    // a:     source matrix, expected to be CV_32FC1.
    // b:     output; replaced by a (rownum x a.cols) matrix.
    // start: first source row (inclusive).
    // end:   upper bound on source rows (exclusive).
    // step:  stride between copied rows; must be positive.
    //
    // The number of rows copied is (end - start) / step using integer division,
    // so a trailing partial stride is dropped. Invalid arguments raise a
    // cv::Exception through CV_Assert instead of dividing by zero or reading
    // outside a.
    void get_rows_sub_mat_a(Mat &a, Mat &b, int start, int end, int step)
    {
        CV_Assert(step > 0);
        CV_Assert(start >= 0 && start <= end && end <= a.rows);
        CV_Assert(a.empty() || a.type() == CV_32FC1);
    
        const int rownum = (end - start) / step;
        const int colnum = a.cols;
    
        b = Mat::zeros(rownum, colnum, CV_32FC1);
    
        for(int i=0; i<rownum; i++)
        {
            const float *src = a.ptr<float>(start + step*i);
            float *dst = b.ptr<float>(i);
    
            for(int j=0; j<colnum; j++)
            {
                dst[j] = src[j];
            }
        }
    }
    
    // Converts the integer values in column 0 of a, taken from rows start,
    // start+step, start+2*step, ..., into floats stored in column 0 of a freshly
    // allocated CV_32FC1 matrix b.
    //
    // a:     source matrix, expected to be CV_32SC1 (only column 0 is read).
    // b:     output; replaced by a (rownum x a.cols) matrix. Columns other than
    //        column 0 stay zero.
    // start: first source row (inclusive).
    // end:   upper bound on source rows (exclusive).
    // step:  stride between converted rows; must be positive.
    //
    // The number of rows produced is (end - start) / step using integer
    // division, so a trailing partial stride is dropped. Invalid arguments raise
    // a cv::Exception through CV_Assert instead of dividing by zero or reading
    // outside a.
    void get_rows_sub_mat_b(Mat &a, Mat &b, int start, int end, int step)
    {
        CV_Assert(step > 0);
        CV_Assert(start >= 0 && start <= end && end <= a.rows);
        CV_Assert(a.empty() || a.type() == CV_32SC1);
    
        const int rownum = (end - start) / step;
    
        b = Mat::zeros(rownum, a.cols, CV_32FC1);
    
        for(int i=0; i<rownum; i++)
        {
            // CV_32SC1 elements are signed 32-bit ints, so read them as int.
            const int label = a.at<int>(start + step*i, 0);
    
            b.at<float>(i, 0) = static_cast<float>(label);
        }
    }
    
    int main()
    {
        CvSVMParams params;
        CvSVM SVM;
        string train_images_path = "./train-images-idx3-ubyte";
        string train_labels_path = "./train-labels-idx1-ubyte";
        string test_images_path = "./t10k-images-idx3-ubyte";
        string test_labels_path = "./t10k-labels-idx1-ubyte";
    
        //set up SVM's parameters
        params.svm_type = CvSVM::C_SVC;
        params.kernel_type = CvSVM::POLY;
        params.gamma = 1.0;
        params.C = 10.0;
        params.nu = 0.5;
        params.degree = 2.10;
        //params.coef0 = 1000.0;
        //params.term_crit = cvTermCriteria(CV_TERMCRIT_EPS, 10000, FLT_EPSILON);
        //params.svm_type = CvSVM::C_SVC;
        //params.kernel_type = CvSVM::LINEAR;
        params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 10000, 1e-6);
    
        //读取标签数据集
        Mat train_labels = read_mnist_label(train_labels_path);
        Mat test_labels = read_mnist_label(test_labels_path);
    
        //读取图像数据集
        Mat train_images = read_mnist_image(train_images_path);
        Mat test_images = read_mnist_image(test_images_path);
    
        Mat train_labels_subs;
        Mat train_images_subs;
        Mat test_labels_subs;
        Mat test_images_subs;
    
        get_rows_sub_mat_b(train_labels, train_labels_subs, 0, 60000, 1);
        get_rows_sub_mat_a(train_images, train_images_subs, 0, 60000, 1);
    
        get_rows_sub_mat_b(test_labels, test_labels_subs, 0, 10000, 1);
        get_rows_sub_mat_a(test_images, test_images_subs, 0, 10000, 1);
    
        SVM.train(train_images_subs, train_labels_subs, Mat(), Mat(), params);
    
        SVM.save("mnist_svm.xml");
    
        //cout << "train end" << endl;
    
        int count = 0;
        for(int i = 0; i < test_images_subs.rows; i++)
        {
            Mat sample = test_images_subs.row(i);
            Mat label;
            float res = SVM.predict(sample);
            int r = 0;
    
            //cout << "res: " << res << " label: " << test_labels_subs.at<float>(i, 0) << endl;
    
            r = std::abs(res - test_labels_subs.at<float>(i, 0)) <= 0.0001 ? 1 : 0;
    
            count += r;
        }
    
        cout << "正确的识别个数 count = " << count << endl;
        cout << "错误率为..." << double(test_images_subs.rows - count) / test_images_subs.rows * 100.0 << "%....\n";
    
        return 0;
    }
    

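    2.编译源码(注意: 源码使用的CvSVM/CvSVMParams属于OpenCV 2.4.x的旧接口, OpenCV 3.0起已改为cv::ml::SVM, 编译时请确保头文件路径和库路径指向OpenCV 2.4.x)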

    $ g++ -o test test.cpp -std=c++11 -I/usr/local/include/opencv4 -L/usr/local/lib -lopencv_core -lopencv_highgui -lopencv_imgproc -lopencv_ml -Wl,-rpath=/usr/local/lib
    

    3.运行及其结果

    $ time ./test
    成功打开标签集...
    幻数(文件格式): 2049
    标签总数: 60000
    开始读取Label数据...
    读取Label数据完毕...
    成功打开标签集...
    幻数(文件格式): 2049
    标签总数: 10000
    开始读取Label数据...
    读取Label数据完毕...
    成功打开图像集...
    幻数(文件格式): 2051
    图片总数: 60000
    每个图像的行数: 28
    每个图像的列数: 28
    开始读取Image数据...
    读取Image数据完毕...
    成功打开图像集...
    幻数(文件格式): 2051
    图片总数: 10000
    每个图像的行数: 28
    每个图像的列数: 28
    开始读取Image数据...
    读取Image数据完毕...
    正确的识别个数 count = 9807
    错误率为...1.93%....
    
    real    3m30.608s
    user    3m29.861s
    sys     0m0.164s
    

    相关文章

      网友评论

          本文标题:opencv中SVM训练mnist手写体

          本文链接:https://www.haomeiwen.com/subject/gesofrtx.html