美文网首页
c++实现lenet推理模型

c++实现lenet推理模型

作者: 一路向后 | 来源:发表于2024-01-13 21:02 被阅读0次

    1.tensor.h

    #ifndef _CONVNET_TENSOR_H_
    #define _CONVNET_TENSOR_H_
    
    #include <vector>
    
    namespace convnet {
    
        // Minimal dense tensor of doubles shared by every layer in this project.
        // `dim` holds the extents (1, 2 or 3 of them, depending on which
        // constructor / resize overload was used) and `data` holds all elements
        // in one flat buffer that the layers index row-major. The layer classes
        // are friends so they can read and write `dim` / `data` directly.
        class Tensor {
        public:
            // Empty tensor, or one with the given extents; elements are only
            // allocated (zero-initialised) when the product of the extents is
            // positive.
            Tensor();
            Tensor(int a);
            Tensor(int a, int b);
            Tensor(int a, int b, int c);
            ~Tensor();
    
            // Change the shape. The 3-D overload always discards the previous
            // contents; the 1-D and 2-D overloads keep existing values where
            // the new size allows.
            void resize(int a);
            void resize(int a, int b);
            void resize(int a, int b, int c);
    
            void relu();          // in place: negative elements become 0
            void sigmoid();       // in place: x -> 1 / (1 + e^-x)
            void argmax(int &s);  // 1-D only: index of the max (ties -> last); s untouched otherwise
    
            void set(std::vector<double> &data);  // replace the elements; `dim` is not changed
            int size();                           // number of stored elements (data.size())
    
        private:
            friend class Linear;
            friend class Conv2d;
            friend class MaxPool2d;
            friend class Reshape;
    
            std::vector<int> dim;       // extents, outermost first
            std::vector<double> data;   // flat element storage
        };
    
    }
    
    #endif
    

    2.tensor.cpp

    #include <math.h>
    #include <iostream>
    #include "tensor.h"
    
    using namespace std;
    
    // Default constructor: a tensor with no dimensions and no elements.
    convnet::Tensor::Tensor()
    {
    }

    // Shaped constructors. On a freshly constructed (empty) tensor every
    // resize() overload records the extents in `dim` and, when their product
    // is positive, allocates that many zero-initialised elements -- exactly
    // the state these constructors must produce, so they simply delegate.
    convnet::Tensor::Tensor(int a)
    {
        resize(a);
    }

    convnet::Tensor::Tensor(int a, int b)
    {
        resize(a, b);
    }

    convnet::Tensor::Tensor(int a, int b, int c)
    {
        resize(a, b, c);
    }
    
    // Reshape to [a]. The element buffer is resized to `a`: values in the
    // retained prefix are kept and any new elements are zero. A non-positive
    // extent empties the buffer, so size() always matches the element count
    // implied by `dim` (stale elements used to be left behind in that case).
    void convnet::Tensor::resize(int a)
    {
        dim = {a};
        data.resize(a > 0 ? a : 0);
    }

    // Reshape to [a, b]; same buffer rules as the 1-D overload, using a*b as
    // the element count.
    void convnet::Tensor::resize(int a, int b)
    {
        dim = {a, b};

        const int count = a * b;
        data.resize(count > 0 ? count : 0);
    }

    // Reshape to [a, b, c]. Unlike the 1-D/2-D overloads the previous contents
    // are always discarded: the buffer comes back zero-filled with a*b*c
    // elements, or empty if that product is not positive.
    void convnet::Tensor::resize(int a, int b, int c)
    {
        dim = {a, b, c};
        data.clear();

        const int count = a * b * c;
        if(count > 0)
        {
            data.resize(count);
        }
    }
    
    // Nothing to release by hand: `dim` and `data` are std::vectors and free
    // their own storage when the tensor is destroyed.
    convnet::Tensor::~Tensor() = default;
    
    void convnet::Tensor::set(std::vector<double> &data)
    {
        this->data = data;
    }
    
    // Number of elements currently stored (the length of the flat buffer).
    int convnet::Tensor::size()
    {
        return static_cast<int>(data.size());
    }
    
    // Rectified linear unit, applied in place: every negative element is
    // replaced by 0; all other elements are left as they are.
    void convnet::Tensor::relu()
    {
        for(double &x : data)
        {
            if(x < 0)
            {
                x = 0;
            }
        }
    }
    
    // Applies the logistic function 1 / (1 + e^-x) to every element in place.
    // Evaluated in double precision with exp(): the elements are doubles, and
    // using the float function expf() rounded every argument and result to
    // float (~7 significant digits). For very negative x, exp(-x) overflows to
    // +inf and the result is the correct limit 0.
    void convnet::Tensor::sigmoid()
    {
        for(double &x : data)
        {
            x = 1.0 / (1.0 + exp(-x));
        }
    }
    
    // Stores in `s` the index of the largest element of a non-empty 1-D
    // tensor. Ties resolve to the LAST maximal index (the comparison is >=).
    // For any other shape `s` is left unchanged.
    void convnet::Tensor::argmax(int &s)
    {
        if(dim.size() != 1 || dim[0] <= 0)
        {
            return;
        }

        int best = 0;

        for(int idx = 1; idx < dim[0]; ++idx)
        {
            if(data[idx] >= data[best])
            {
                best = idx;
            }
        }

        s = best;
    }
    

    3.conv2d.h

    #ifndef _CONVNET_CONV2D_H_
    #define _CONVNET_CONV2D_H_
    
    #include "tensor.h"
    
    namespace convnet {
    
        // 2-D convolution layer (stride 1, no padding: see forward()).
        // Owns its kernel and bias; `input` / `output` are borrowed pointers
        // bound by set() and never freed by this class.
        class Conv2d {
        public:
            // Full form: input width/height/channels, output width/height/
            // channels and a kw x kh kernel.
            Conv2d(int iw, int ih, int ic, int ow, int oh, int oc, int kw, int kh);
            // Short form: in/out channels and a square kw x kw kernel; the
            // spatial sizes are derived later by set() from the input tensor.
            Conv2d(int ic, int oc, int kw);
            ~Conv2d();
    
            void set(Tensor *i, Tensor *o);                                // bind tensors, derive sizes, shape output
            void setargs(std::vector<double> &k);                          // single kw*kh kernel, no bias
            void setargs(std::vector<double> &k, std::vector<double> &b);  // oc*ic*kh*kw kernel + oc biases
            void forward();
            void print(int type);                                          // 0: dump input, 1: dump output
    
        private:
            int iw;   // input width
            int ih;   // input height
            int ic;   // input channels
            int ow;   // output width
            int oh;   // output height
            int oc;   // output channels
            int kw;   // kernel width
            int kh;   // kernel height
            Tensor kernel;    // flat weights, read as [((o*ic + c)*kh + p)*kw + q]
            Tensor bias;      // one value per output channel
            Tensor *input;    // not owned
            Tensor *output;   // not owned
        };
    
    }
    
    #endif
    

    4.conv2d.cpp

    #include <stdio.h>
    #include <iostream>
    #include <cassert>
    #include "conv2d.h"
    
    using namespace std;
    
    // Fully specified layer: input iw x ih with ic channels, output ow x oh
    // with oc channels, kernel kw x kh. The tensor pointers start out null
    // (rather than indeterminate) until set() binds them.
    convnet::Conv2d::Conv2d(int iw, int ih, int ic, int ow, int oh, int oc, int kw, int kh)
        : iw(iw), ih(ih), ic(ic), ow(ow), oh(oh), oc(oc), kw(kw), kh(kh),
          input(nullptr), output(nullptr)
    {
    }

    // Channel counts plus a square kw x kw kernel. Spatial sizes start at 0
    // so that set() derives them from the input tensor.
    convnet::Conv2d::Conv2d(int ic, int oc, int kw)
        : iw(0), ih(0), ic(ic), ow(0), oh(0), oc(oc), kw(kw), kh(kw),
          input(nullptr), output(nullptr)
    {
    }
    
    // The kernel and bias tensors clean up after themselves; the input/output
    // pointers are borrowed and deliberately not deleted.
    convnet::Conv2d::~Conv2d() = default;
    
    // Binds the layer to its input/output tensors and finalises the sizes.
    //  - If the input's element count differs from ic*ih*iw: when explicit
    //    input sizes were given, the input is reshaped to (ic, ih, iw);
    //    otherwise, for a 3-D input, ih/iw are read from its dims.
    //  - If the output sizes are unknown (0) they are derived for a stride-1,
    //    no-padding convolution, o = i - k + 1, and the output is (re)shaped
    //    to (oc, oh, ow), which zero-fills it.
    void convnet::Conv2d::set(Tensor *i, Tensor *o)
    {
        input = i;
        output = o;

        if(input->size() != iw * ih * ic)
        {
            if(iw * ih != 0)
            {
                input->resize(ic, ih, iw);
            }
            else if(input->dim.size() == 3)
            {
                // Tensors are laid out (channels, height, width): see the
                // resize(ic, ih, iw) call above and the row stride `iw` used
                // by forward(). dim[1] is therefore the height and dim[2] the
                // width (these were swapped, which broke non-square inputs).
                ih = input->dim[1];
                iw = input->dim[2];
            }
        }

        if(output->size() != ow * oh * oc || ow * oh == 0)
        {
            if(ow * oh == 0)
            {
                // A kernel larger than the input has no valid output. abs()
                // used to turn that into a bogus positive size, after which
                // forward() read past the end of the input.
                assert(iw >= kw && ih >= kh);

                ow = (iw >= kw) ? iw - kw + 1 : 0;
                oh = (ih >= kh) ? ih - kh + 1 : 0;
            }

            output->resize(oc, oh, ow);
        }
    }
    
    // Loads a single kw*kh kernel (no bias, one channel). A vector of any
    // other length is silently ignored.
    void convnet::Conv2d::setargs(std::vector<double> &k)
    {
        const int taps = kw * kh;

        if(k.size() != taps)
        {
            return;
        }

        kernel.resize(taps);
        kernel.set(k);
    }
    
    // Loads the full parameter set. `k` holds oc*ic*kh*kw weights ordered
    // (output channel, input channel, row, column) and `b` one bias per output
    // channel. Sizes are only checked via assert (nothing under NDEBUG).
    void convnet::Conv2d::setargs(std::vector<double> &k, std::vector<double> &b)
    {
        bias.dim.clear();
        bias.dim.push_back(static_cast<int>(b.size()));
        bias.data = b;

        assert(k.size() == kw * kh * ic * oc);
        assert(b.size() == oc);

        kernel.set(k);
    }
    
    void convnet::Conv2d::forward()
    {
        int i, j, k, l, p, q;
    
        for(i=0; i<oc; i++)
        {
            for(j=0; j<oh; j++)
            {
                for(k=0; k<ow; k++)
                {
                    output->data[i*oh*ow+j*ow+k] = 0.0;
    
                    for(l=0; l<ic; l++)
                    {
                        //printf("i = %d, kh = %d, kw = %d\n", i, kh, kw);
    
                        for(p=0; p<kh; p++)
                        {
                            for(q=0; q<kw; q++)
                            {
                                //printf("p = %d, q = %d, u = %d, k = %d\n", p, q, i*kh*kw+p*kw+q, kernel.data[i*kh*kw+p*kw+q]);
    
                                output->data[i*oh*ow+j*ow+k] += (kernel.data[(i*ic+l)*kh*kw+p*kw+q] * input->data[l*ih*iw+(j+p)*iw+(k+q)]);
                                //output->data[i*oh*ow+j*ow+k] += kernel.data[(l)*kh*kw+p*kw+q] * input->data[l*ih*iw+(j+p)*iw+(k+q)];
                            }
                        }
    
                        //output->data[i*oh*ow+j*ow+k] += bias.data[i];
                    }
    
                    output->data[i*oh*ow+j*ow+k] += bias.data[i];
                }
            }
        }
    }
    
    void convnet::Conv2d::print(int type)
    {
        if(type == 0)
        {
            int m = iw * ih * ic;
    
            for(int i=0; i<m; i++)
            {
                printf("%.6lf ", input->data[i]);
            }
    
            printf("\n");
        }
        else if(type == 1)
        {
            int n = ow * oh * oc;
    
            for(int i=0; i<n; i++)
            {
                printf("%.6lf ", output->data[i]);
            }
    
            printf("\n");
        }
    }
    

    5.maxpool2d.h

    #ifndef _CONVNET_MAXPOOL2D_H_
    #define _CONVNET_MAXPOOL2D_H_
    
    #include "tensor.h"
    
    namespace convnet {
    
        // 2-D max pooling layer. `input` / `output` are borrowed pointers
        // bound by set(); this class never frees them.
        // (The include guard's #ifndef was missing, leaving the closing
        // #endif unmatched.)
        class MaxPool2d {
        public:
            MaxPool2d(int iw, int ih, int w, int h, int stride);  // fixed iw x ih input, w x h window
            MaxPool2d(int w, int h, int stride);                  // w x h window, sizes derived by set()
            MaxPool2d(int w, int stride);                         // square w x w window, sizes derived by set()
            ~MaxPool2d();
    
            void set(Tensor *i, Tensor *o);   // bind tensors, derive sizes, shape output
            void forward();
            void print(int type);             // 0: dump input, 1: dump output
    
        private:
            int ic;       // input channels
            int iw;       // input width
            int ih;       // input height
            int ow;       // output width
            int oh;       // output height
            int oc;       // output channels (set() makes this equal to ic)
            int kw;       // window width
            int kh;       // window height
            int stride;
            Tensor *input;    // not owned
            Tensor *output;   // not owned
        };
    
    }
    
    #endif
    

    6.maxpool2d.cpp

    #include <stdio.h>
    #include <iostream>
    #include "maxpool2d.h"
    
    using namespace std;
    
    // Pooling over a fixed iw x ih input with a w x h window; the output is
    // iw/stride x ih/stride. The window comes from the w/h parameters (it was
    // previously "assigned" from the uninitialised members themselves).
    // A single channel is used: set() reshapes such an input to the 2-D shape
    // (iw, ih), which its own 2-D branch also treats as ic = oc = 1.
    convnet::MaxPool2d::MaxPool2d(int iw, int ih, int w, int h, int stride)
        : ic(1), iw(iw), ih(ih), ow(iw / stride), oh(ih / stride), oc(1),
          kw(w), kh(h), stride(stride), input(nullptr), output(nullptr)
    {
    }

    // Pooling with a w x h window; every size is derived later by set().
    convnet::MaxPool2d::MaxPool2d(int w, int h, int stride)
        : ic(0), iw(0), ih(0), ow(0), oh(0), oc(0),
          kw(w), kh(h), stride(stride), input(nullptr), output(nullptr)
    {
    }

    // Pooling with a square w x w window; every size is derived by set().
    convnet::MaxPool2d::MaxPool2d(int w, int stride)
        : ic(0), iw(0), ih(0), ow(0), oh(0), oc(0),
          kw(w), kh(w), stride(stride), input(nullptr), output(nullptr)
    {
    }
    
    // No owned resources: the input/output tensors are borrowed pointers.
    convnet::MaxPool2d::~MaxPool2d() = default;
    
    // Binds the pooling layer to its tensors and works out the sizes.
    //  - If the constructor fixed iw/ih, the input is resized to the 2-D shape
    //    (iw, ih) whenever its element count differs.
    //  - Otherwise sizes come from the input: a 2-D input is one channel with
    //    dim = (iw, ih); a 3-D input is (channels, ih, iw). The output is then
    //    iw/stride x ih/stride (integer division) with as many channels.
    // The output is (re)shaped to (oc, oh, ow), which zero-fills it, when that
    // shape is non-empty and its element count differs from the current one.
    void convnet::MaxPool2d::set(Tensor *i, Tensor *o)
    {
        if(iw != 0 && ih != 0)
        {
            if(iw * ih != i->size())
            {
                i->resize(iw, ih);
            }
        }
        else
        {
            if(i->dim.size() == 2)
            {
                iw = i->dim[0];
                ih = i->dim[1];
                ic = 1;
                oc = 1;
            }
            else if(i->dim.size() == 3)
            {
                ic = i->dim[0];
                ih = i->dim[1];
                iw = i->dim[2];
                oc = ic;
            }

            ow = iw / stride;
            oh = ih / stride;
        }

        if(ow * oh * oc != 0 && ow * oh * oc != o->size())
        {
            o->resize(oc, oh, ow);
        }

        input = i;
        output = o;
    }
    
    void convnet::MaxPool2d::forward()
    {
        int i, j, k;
        int p, q;
    
        for(i=0; i<oc; i++)
        {
            for(j=0; j<oh; j++)
            {
                for(k=0; k<ow; k++)
                {
                    output->data[i*oh*ow+j*ow+k] = input->data[i*ih*iw+j*stride*iw+k*stride];
    
                    for(p=0; p<stride; p++)
                    {
                        for(q=0; q<stride; q++)
                        {
                            if(input->data[i*ih*iw+(j*stride+p)*iw+(k*stride+q)] > output->data[i*oh*ow+j*ow+k])
                            {
                                output->data[i*oh*ow+j*ow+k] = input->data[i*ih*iw+(j*stride+p)*iw+(k*stride+q)];
                            }
                        }
                    }
                }
            }
        }
    }
    
    void convnet::MaxPool2d::print(int type)
    {
        if(type == 0)
        {
            int m = ic * ih * iw;
    
            for(int i=0; i<m; i++)
            {
                printf("%.6lf ", input->data[i]);
            }
    
            printf("\n");
        }
        else if(type == 1)
        {
            int n = oc * oh * ow;
    
            for(int i=0; i<n; i++)
            {
                printf("%.6lf ", output->data[i]);
            }
    
            printf("\n");
        }
    }
    

    7.reshape.h

    #ifndef _CONVNET_RESHAPE_H_
    #define _CONVNET_RESHAPE_H_
    
    #include "tensor.h"
    
    namespace convnet {
    
        // Flatten layer. Constructed as Reshape(-1, N): set() checks that a
        // 2-D or 3-D input holds N elements and shapes the output as a 1-D
        // tensor of N; forward() copies the values unchanged. Other (a, b)
        // arguments are stored, but only the (-1, N > 0) form reshapes.
        class Reshape {
        public:
            Reshape(int a, int b);
            ~Reshape();
    
            void set(Tensor *input, Tensor *output);   // bind tensors (and shape output for -1, N)
            void forward();                            // copy input values to output
            void print(int type);                      // 0: dump input, 1: dump output
    
        private:
            std::vector<int> dim;   // requested shape (a, b)
            Tensor *input;          // not owned
            Tensor *output;         // not owned
        };
    
    }
    
    #endif
    

    8.reshape.cpp

    #include <stdio.h>
    #include <math.h>
    #include "reshape.h"
    #include <cassert>
    
    convnet::Reshape::Reshape(int a, int b)
    {
        dim.resize(2);
    
        dim[0] = a;
        dim[1] = b;
    }
    
    // No owned resources: `dim` is a std::vector and the tensors are borrowed.
    convnet::Reshape::~Reshape() = default;
    
    // Binds the layer to its tensors. For the flatten form (-1, N) with N > 0
    // it asserts that a 2-D or 3-D input holds exactly N elements and resizes
    // the output to a 1-D tensor of N. For any other (a, b) the output's shape
    // is left alone; the pointers are bound in every case.
    void convnet::Reshape::set(Tensor *input, Tensor *output)
    {
        if(dim.empty())
        {
            return;
        }

        const bool flatten = dim[0] == -1 && dim[1] > 0;

        if(flatten)
        {
            const auto rank = input->dim.size();

            if(rank == 3)
            {
                assert(input->dim[0] * input->dim[1] * input->dim[2] == dim[1]);
            }
            else if(rank == 2)
            {
                assert(input->dim[0] * input->dim[1] == dim[1]);
            }

            output->resize(dim[1]);
        }

        this->input = input;
        this->output = output;
    }
    
    // Flattening only relabels the shape; the element order is unchanged, so
    // the forward pass just copies the values into the output tensor.
    void convnet::Reshape::forward()
    {
        const std::vector<double> &src = input->data;
        output->data = src;
    }
    
    void convnet::Reshape::print(int type)
    {
        if(type == 0)
        {
            int m = abs(dim[0] * dim[1]);
    
            //printf("input size: %p\n", input);
    
            for(int i=0; i<m; i++)
            {
                printf("%.6lf ", input->data[i]);
            }
    
            printf("\n");
        }
        else if(type == 1)
        {
            int n = abs(dim[0] * dim[1]);
    
            for(int i=0; i<n; i++)
            {
                printf("%.6lf ", output->data[i]);
            }
    
            printf("\n");
        }
    }
    

    9.linear.h

    #ifndef _CONVNET_LINEAR_H_
    #define _CONVNET_LINEAR_H_
    
    #include "tensor.h"
    
    namespace convnet {
    
        // Fully connected layer with m inputs and n outputs:
        //   output[j] = sum_i input[i] * weight.data[j*m + i] + bias[j]
        // `input` / `output` are borrowed pointers bound by set().
        class Linear {
        public:
            Linear(int m, int n);   // m input features, n output features
    
            void set(Tensor *i, Tensor *o);   // bind tensors, resizing them to m / n elements if needed
            void forward();
            void setargs(std::vector<double> &w, std::vector<double> &b);   // n*m weights + n biases
            void print(int type);             // 0: dump input, 1: dump output
    
        private:
            int m;           // input features
            int n;           // output features
            Tensor weight;   // m*n values, read as [out*m + in]
            Tensor bias;     // n values
            Tensor *input;   // not owned
            Tensor *output;  // not owned
        };
    
    }
    
    #endif
    

    10.linear.cpp

    #include <stdio.h>
    #include <cassert>
    #include "linear.h"
    
    // Allocates zeroed storage for an m-input, n-output layer (m*n weights,
    // n biases). The tensor pointers start out null (rather than
    // indeterminate) until set() binds them.
    convnet::Linear::Linear(int m, int n)
        : m(m), n(n), input(nullptr), output(nullptr)
    {
        weight.resize(m, n);
        bias.resize(n);
    }
    
    // Binds the layer to its tensors. Either tensor whose element count does
    // not match (m for the input, n for the output) is resized to a 1-D tensor
    // of that length first.
    void convnet::Linear::set(Tensor *i, Tensor *o)
    {
        const bool input_mismatch = (i->size() != m);
        if(input_mismatch)
        {
            i->resize(m);
        }

        const bool output_mismatch = (o->size() != n);
        if(output_mismatch)
        {
            o->resize(n);
        }

        input = i;
        output = o;
    }
    
    // Installs the n*m weights (read by forward() as [out*m + in]) and the n
    // biases. Sizes are only checked via assert (nothing under NDEBUG).
    void convnet::Linear::setargs(std::vector<double> &w, std::vector<double> &b)
    {
        assert(w.size() == m * n && b.size() == n);

        weight.data.assign(w.begin(), w.end());
        bias.data.assign(b.begin(), b.end());
    }
    
    // Dense matrix-vector product plus bias:
    //   output[row] = sum_col input[col] * W[row*m + col] + bias[row]
    void convnet::Linear::forward()
    {
        for(int row = 0; row < n; ++row)
        {
            const int offset = row * m;
            double acc = 0.0;

            for(int col = 0; col < m; ++col)
            {
                acc += input->data[col] * weight.data[offset + col];
            }

            output->data[row] = acc + bias.data[row];
        }
    }
    
    void convnet::Linear::print(int type)
    {
        if(type == 0)
        {
            for(int i=0; i<m; i++)
            {
                printf("%.6lf ", input->data[i]);
            }
    
            printf("\n");
        }
        else if(type == 1)
        {
            for(int i=0; i<n; i++)
            {
                printf("%.6lf ", output->data[i]);
            }
    
            printf("\n");
        }
    }
    

    11.main.cpp

    #include <cassert>
    #include <cctype>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>
    #include <string>
    #include <vector>
    #include "conv2d.h"
    #include "maxpool2d.h"
    #include "linear.h"
    #include "reshape.h"
    #include "pugixml.hpp"
    
    using namespace std;
    using namespace convnet;
    
    // Shared tokenizer behind both get_numbers() overloads.
    // Scans `line` from the start, skipping whitespace. Each maximal run of
    // characters accepted by `is_token_char` forms one token, and a single ','
    // directly after a token is consumed as its separator. Every token is
    // passed to `emit`. Scanning stops at the first character that is neither
    // whitespace nor a token character (e.g. a second consecutive comma).
    // Tokens are collected in a std::string, so arbitrarily long tokens are
    // safe (the former fixed char[32] buffer overflowed on them).
    template <typename IsTokenChar, typename Emit>
    static void scan_number_tokens(const std::string &line, IsTokenChar is_token_char, Emit emit)
    {
        const std::string::size_type len = line.length();
        std::string::size_type i = 0;

        while(i < len)
        {
            if(is_token_char(line[i]))
            {
                std::string token;

                while(i < len && is_token_char(line[i]))
                {
                    token += line[i];
                    i++;
                }

                if(i < len && line[i] == ',')
                {
                    i++;
                }

                emit(token);
            }
            // isspace() needs a value representable as unsigned char; passing
            // a negative (non-ASCII) char directly is undefined behaviour.
            else if(isspace(static_cast<unsigned char>(line[i])))
            {
                i++;
            }
            else
            {
                break;
            }
        }
    }

    // Appends to `s` every integer token found at the start of `line`.
    // Tokens consist of digits and '+' / '-' and are converted with atoi, so a
    // malformed token such as "1-2" yields the value of its valid prefix (1).
    void get_numbers(const std::string &line, std::vector<int> &s)
    {
        scan_number_tokens(
            line,
            [](char c) { return c == '+' || c == '-' || (c >= '0' && c <= '9'); },
            [&s](const std::string &token) { s.push_back(atoi(token.c_str())); });
    }

    // Appends to `s` every floating-point token found at the start of `line`.
    // Tokens consist of digits, '+', '-', '.' and 'e' and are converted with
    // strtod (which parses the longest valid prefix of each token).
    void get_numbers(const std::string &line, std::vector<double> &s)
    {
        scan_number_tokens(
            line,
            [](char c)
            {
                return c == '+' || c == '-' || c == 'e' || c == '.' || (c >= '0' && c <= '9');
            },
            [&s](const std::string &token) { s.push_back(strtod(token.c_str(), nullptr)); });
    }
    
    int main()
    {
        Tensor *input = NULL;
        std::vector<Tensor *> outputs;
        std::vector<Conv2d *> conv2ds;
        std::vector<MaxPool2d *> maxpool2ds;
        std::vector<Reshape *> reshapes;
        std::vector<Linear *> linears;
        std::vector<int> actfuncs;
        std::vector<std::vector<int>> args;
        std::vector<int> args_brh;
        std::vector<std::string> types;
        std::vector<std::string> names;
        std::vector<int> intypes;
        std::vector<double> nums[2];
        int k = 0;
        int u = 0;
        int v = 0;
        int w = 0;
        int l = 0;
    
        pugi::xml_document doc;
        pugi::xml_parse_result result = doc.load_file("lenet.xml");
    
        if(!result)
        {
            return -1;
        }
    
        pugi::xml_node xnodes = doc.child("lenet");
    
        for(pugi::xml_node xnode = xnodes.first_child(); xnode != NULL; xnode = xnode.next_sibling())
        {
            std::string type = xnode.child("type").text().as_string();
            std::string name = xnode.child("name").text().as_string();
    
            //cout << "Type: " << type << endl;
    
            if(type == "input")
            {
                std::string value = xnode.child("value").text().as_string();
    
                get_numbers(value, args_brh);
    
                if(args_brh.size() == 3)
                {
                    types.push_back(type);
                    names.push_back(name);
                    args.push_back(args_brh);
    
                    if(xnode.child("actfunc") != NULL)
                    {
                        std::string func = xnode.child("actfunc").text().as_string();
    
                        if(func == "relu")
                        {
                            actfuncs.push_back(1);
                        }
                        else
                        {
                            actfuncs.push_back(-1);
                        }
                    }
                    else
                    {
                        actfuncs.push_back(-1);
                    }
                }
    
                k++;
    
                args_brh.clear();
            }
            else if(type == "Conv2d")
            {
                std::string value = xnode.child("value").text().as_string();
    
                get_numbers(value, args_brh);
    
                if(args_brh.size() == 3)
                {
                    types.push_back(type);
                    names.push_back(name);
                    args.push_back(args_brh);
    
                    if(xnode.child("actfunc") != NULL)
                    {
                        std::string func = xnode.child("actfunc").text().as_string();
    
                        if(func == "relu")
                        {
                            actfuncs.push_back(1);
                        }
                        else
                        {
                            actfuncs.push_back(-1);
                        }
                    }
                    else
                    {
                        actfuncs.push_back(-1);
                    }
                }
    
                k++;
    
                args_brh.clear();
            }
            else if(type == "MaxPool2d")
            {
                std::string value = xnode.child("value").text().as_string();
    
                get_numbers(value, args_brh);
    
                if(args_brh.size() == 2)
                {
                    types.push_back(type);
                    names.push_back(name);
                    args.push_back(args_brh);
    
                    if(xnode.child("actfunc") != NULL)
                    {
                        std::string func = xnode.child("actfunc").text().as_string();
    
                        if(func == "relu")
                        {
                            actfuncs.push_back(1);
                        }
                        else
                        {
                            actfuncs.push_back(-1);
                        }
                    }
                    else
                    {
                        actfuncs.push_back(-1);
                    }
                }
    
                k++;
    
                args_brh.clear();
            }
            else if(type == "reshape")
            {
                std::string value = xnode.child("value").text().as_string();
    
                get_numbers(value, args_brh);
    
                if(args_brh.size() == 2)
                {
                    types.push_back(type);
                    names.push_back(name);
                    args.push_back(args_brh);
    
                    if(xnode.child("actfunc") != NULL)
                    {
                        std::string func = xnode.child("actfunc").text().as_string();
    
                        if(func == "relu")
                        {
                            actfuncs.push_back(1);
                        }
                        else
                        {
                            actfuncs.push_back(-1);
                        }
                    }
                    else
                    {
                        actfuncs.push_back(-1);
                    }
                }
    
                k++;
    
                args_brh.clear();
            }
            else if(type == "Linear")
            {
                std::string value = xnode.child("value").text().as_string();
    
                get_numbers(value, args_brh);
    
                if(args_brh.size() == 2)
                {
                    types.push_back(type);
                    names.push_back(name);
                    args.push_back(args_brh);
    
                    if(xnode.child("actfunc") != NULL)
                    {
                        std::string func = xnode.child("actfunc").text().as_string();
    
                        if(func == "relu")
                        {
                            actfuncs.push_back(1);
                        }
                        else
                        {
                            actfuncs.push_back(-1);
                        }
                    }
                    else
                    {
                        actfuncs.push_back(-1);
                    }
                }
    
                k++;
    
                args_brh.clear();
            }
        }
    
        for(k=0; k<types.size(); k++)
        {
            string type = types[k];
    
            //cout << "type: " << type << ", name: " << names[k] << endl;
    
            /*if(actfuncs[k] == 1)
            {
                cout << "call function relu" << endl;
            }*/
    
            if(type == "input")
            {
                input = new Tensor(args[k][0], args[k][1], args[k][2]);
    
                intypes.push_back(1);
            }
            else if(type == "Conv2d")
            {
                conv2ds.push_back(new Conv2d(args[k][0], args[k][1], args[k][2]));
                outputs.push_back(new Tensor());
    
                u = conv2ds.size() - 1;
                v = outputs.size() - 1;
    
                if(v == 0)
                {
                    conv2ds[u]->set(input, outputs[v]);
                }
                else
                {
                    conv2ds[u]->set(outputs[v-1], outputs[v]);
                }
    
                intypes.push_back(2);
            }
            else if(type == "MaxPool2d")
            {
                maxpool2ds.push_back(new MaxPool2d(args[k][0], args[k][1]));
                outputs.push_back(new Tensor());
    
                u = maxpool2ds.size() - 1;
                v = outputs.size() - 1;
    
                maxpool2ds[u]->set(outputs[v-1], outputs[v]);
    
                intypes.push_back(3);
            }
            else if(type == "reshape")
            {
                reshapes.push_back(new Reshape(args[k][0], args[k][1]));
                outputs.push_back(new Tensor());
    
                u = reshapes.size() - 1;
                v = outputs.size() - 1;
    
                reshapes[u]->set(outputs[v-1], outputs[v]);
    
                intypes.push_back(4);
            }
            else if(type == "Linear")
            {
                linears.push_back(new Linear(args[k][0], args[k][1]));
                outputs.push_back(new Tensor());
    
                u = linears.size() - 1;
                v = outputs.size() - 1;
    
                linears[u]->set(outputs[v-1], outputs[v]);
    
                intypes.push_back(5);
            }
        }
    
        FILE *fp = fopen("lenet.mod", "r");
        char line[128];
        int nodes = 0;
    
        if(fp == NULL)
        {
            return -1;
        }
    
        memset(line, 0x00, sizeof(line));
    
        fread(line, 1, 6, fp);
        fread(&nodes, 1, 4, fp);
    
        if(strcmp(line, "lenet") != 0)
        {
            return -1;
        }
    
        u = 0;
        v = 0;
        w = 0;
    
        for(k=0; k<types.size(); k++)
        {
            if(intypes[k] == 1)
            {
                //cout << "do nothing" << endl;
            }
            else if(intypes[k] == 2)
            {
                //cout << "deal conv2d" << endl;
    
                nums[0].clear();
    
                memset(line, 0x00, sizeof(line));
    
                fread(line, 1, names[k].length()+8, fp);
                fread(&nodes, 1, 4, fp);
    
                //printf("line = %s, number = %d\n", line, nodes);
    
                nums[0].resize(nodes);
    
                for(int i=0; i<nodes; i++)
                {
                    fread(&nums[0][i], 1, sizeof(double), fp);
    
                    //printf("nums = %.6lf\n", nums[0][i]);
                }
    
                nums[1].clear();
    
                memset(line, 0x00, sizeof(line));
    
                fread(line, 1, names[k].length()+6, fp);
                fread(&nodes, 1, 4, fp);
    
                //printf("line = %s, number = %d\n", line, nodes);
    
                nums[1].resize(nodes);
    
                for(int i=0; i<nodes; i++)
                {
                    fread(&nums[1][i], 1, sizeof(double), fp);
    
                    //printf("nums = %.6lf\n", nums[1][i]);
                }
    
                conv2ds[u++]->setargs(nums[0], nums[1]);
            }
            else if(intypes[k] == 5)
            {
                //cout << "deal linear" << endl;
    
                nums[0].clear();
    
                memset(line, 0x00, sizeof(line));
    
                fread(line, 1, names[k].length()+8, fp);
                fread(&nodes, 1, 4, fp);
    
                //printf("line = %s, number = %d\n", line, nodes);
    
                nums[0].resize(nodes);
    
                for(int i=0; i<nodes; i++)
                {
                    fread(&nums[0][i], 1, sizeof(double), fp);
    
                    //printf("nums = %.6lf\n", nums[0][i]);
                }
    
                nums[1].clear();
    
                memset(line, 0x00, sizeof(line));
    
                fread(line, 1, names[k].length()+6, fp);
                fread(&nodes, 1, 4, fp);
    
                //printf("line = %s, number = %d\n", line, nodes);
    
                nums[1].resize(nodes);
    
                for(int i=0; i<nodes; i++)
                {
                    fread(&nums[1][i], 1, sizeof(double), fp);
    
                    //printf("nums = %.6lf\n", nums[1][i]);
                }
    
                linears[v++]->setargs(nums[0], nums[1]);
            }
        }
    
        fclose(fp);
    
        nums[0].clear();
        nums[1].clear();
    
        fp = fopen("input.txt", "r");
    
        if(fp == NULL)
        {
            return -1;
        }
    
        while(!feof(fp))
        {
            memset(line, 0x00, sizeof(line));
    
            fgets(line, 127, fp);
    
            get_numbers(line, nums[0]);
        }
    
        fclose(fp);
    
        //cout << "nums = " << nums[0].size() << endl;
    
        assert(nums[0].size() == input->size());
    
        input->set(nums[0]);
    
        nums[0].clear();
    
        u = 0;
        v = 0;
        w = 0;
        l = 0;
    
        //printf("types.size = %d, intypes.size = %d\n", types.size(), intypes.size());
    
        for(k=1; k<types.size(); k++)
        {
            switch(intypes[k])
            {
                case 2:
                    conv2ds[u]->forward();
                    if(actfuncs[k] == 1)
                    {
                        outputs[k-1]->relu();
                        //cout << "call relu" << endl;
                    }
                    //if(u==1)
                    //conv2ds[u]->print(1);
                    u++;
                    break;
                case 3:
                    maxpool2ds[v]->forward();
                    if(actfuncs[k] == 1)
                    {
                        outputs[k-1]->relu();
                    }
                    //if(v==1)
                    //maxpool2ds[v]->print(1);
                    v++;
                    break;
    
                case 4:
                    reshapes[w]->forward();
                    if(actfuncs[k] == 1)
                    {
                        outputs[k-1]->relu();
                    }
                    //cout << "reshape" << endl;
                    //reshapes[w]->print(0);
                    w++;
                    break;
    
                case 5:
                    linears[l]->forward();
                    if(actfuncs[k] == 1)
                    {
                        outputs[k-1]->relu();
                    }
                    //linears[l]->print(1);
                    l++;
                    break;
                    
                default:
                    break;
            }
        }
    
        k = types.size() - 2;
    
        u = 0;
    
        outputs[k]->argmax(u);
    
        printf("predict: %d\n", u);
    
        if(input != NULL)
        {
            delete input;
    
            input = NULL;
        }
    
        for(k=0; k<conv2ds.size(); k++)
        {
            if(conv2ds[k] != NULL)
            {
                delete conv2ds[k];
    
                conv2ds[k] = NULL;
            }
        }
    
        conv2ds.clear();
    
        for(k=0; k<outputs.size(); k++)
        {
            if(outputs[k] != NULL)
            {
                delete outputs[k];
    
                outputs[k] = NULL;
            }
        }
    
        outputs.clear();
    
        for(k=0; k<maxpool2ds.size(); k++)
        {
            if(maxpool2ds[k] != NULL)
            {
                delete maxpool2ds[k];
    
                maxpool2ds[k] = NULL;
            }
        }
    
        maxpool2ds.clear();
    
        for(k=0; k<reshapes.size(); k++)
        {
            if(reshapes[k] != NULL)
            {
                delete reshapes[k];
    
                reshapes[k] = NULL;
            }
        }
    
        reshapes.clear();
    
        for(k=0; k<linears.size(); k++)
        {
            if(linears[k] != NULL)
            {
                delete linears[k];
    
                linears[k] = NULL;
            }
        }
    
        linears.clear();
    
        return 0;
    }
    

    12.Makefile

    # Build script for the LeNet inference demo.
    # NOTE: the blog rendering shows recipe lines indented with spaces; when saved
    # as a real Makefile, each recipe line must begin with a TAB character.
    CXX=g++
    STD=-std=c++11
    DEBUG=-g
    LDFLAGS=
    # Extra compiler flags, e.g. `make CXXFLAGS=-O2`. Every compile recipe below
    # references this exact name, so keep the spelling in sync.
    CXXFLAGS=
    OBJS=linear.o tensor.o conv2d.o maxpool2d.o pugixml.o reshape.o
    
    # `clean` produces no file of that name; mark it phony so it always runs.
    .PHONY: clean
    
    # main.cpp is compiled and linked in one step, so it also needs $(CXXFLAGS).
    lenet: main.cpp $(OBJS)
        $(CXX) $(DEBUG) $(CXXFLAGS) -o lenet main.cpp $(OBJS) $(STD) $(LDFLAGS)
    
    linear.o: linear.cpp
        $(CXX) $(DEBUG) -c linear.cpp $(STD) $(CXXFLAGS)
    
    tensor.o: tensor.cpp tensor.h
        $(CXX) $(DEBUG) -c tensor.cpp $(STD) $(CXXFLAGS)
    
    conv2d.o: conv2d.cpp conv2d.h
        $(CXX) $(DEBUG) -c conv2d.cpp $(STD) $(CXXFLAGS)
    
    maxpool2d.o: maxpool2d.cpp maxpool2d.h
        $(CXX) $(DEBUG) -c maxpool2d.cpp $(STD) $(CXXFLAGS)
    
    reshape.o: reshape.cpp reshape.h
        $(CXX) $(DEBUG) -c reshape.cpp $(STD) $(CXXFLAGS)
    
    pugixml.o: pugixml.cpp pugixml.hpp
        $(CXX) $(DEBUG) -c pugixml.cpp $(STD) $(CXXFLAGS)
    
    # Only regular files are removed, so -f (not -rf) is sufficient.
    clean:
        rm -f lenet
        rm -f $(OBJS)
    

    13.lenet.xml

    <?xml version="1.0" encoding="UTF-8"?>
    <lenet>
        <node>
            <type>input</type>
            <name>input</name>
            <value>1, 28, 28</value>
        </node>
        <node>
            <type>Conv2d</type>
            <name>conv1</name>
            <value>1, 6, 5</value>
            <actfunc>relu</actfunc>
        </node>
        <node>
            <type>MaxPool2d</type>
            <name>pool1</name>
            <value>2, 2</value>
        </node>
        <node>
            <type>Conv2d</type>
            <name>conv2</name>
            <value>6, 16, 5</value>
            <actfunc>relu</actfunc>
        </node>
        <node>
            <type>MaxPool2d</type>
            <name>pool2</name>
            <value>2, 2</value>
        </node>
        <node>
            <type>reshape</type>
            <name>view</name>
            <value>-1, 256</value>
        </node>
        <node>
            <type>Linear</type>
            <name>fc1</name>
            <value>256, 120</value>
            <actfunc>relu</actfunc>
        </node>
        <node>
            <type>Linear</type>
            <name>fc2</name>
            <value>120, 84</value>
            <actfunc>relu</actfunc>
        </node>
        <node>
            <type>Linear</type>
            <name>fc3</name>
            <value>84, 10</value>
        </node>
    </lenet>
    

    14.input.txt

    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.6450,
     1.9305,  1.5996,  1.4978,  0.3395,  0.0340, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  2.4015,
     2.8088,  2.8088,  2.8088,  2.8088,  2.6433,  2.0960,  2.0960,
     2.0960,  2.0960,  2.0960,  2.0960,  2.0960,  2.0960,  1.7396,
     0.2377, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.4286,
     1.0268,  0.4922,  1.0268,  1.6505,  2.4651,  2.8088,  2.4396,
     2.8088,  2.8088,  2.8088,  2.7578,  2.4906,  2.8088,  2.8088,
     1.3577, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.2078,  0.4159, -0.2460,
     0.4286,  0.4286,  0.4286,  0.3268, -0.1569,  2.5797,  2.8088,
     0.9250, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242,  0.6322,  2.7960,  2.2360,
    -0.1951, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.1442,  2.5415,  2.8215,  0.6322,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242,  1.2177,  2.8088,  2.6051,  0.1358,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242,  0.3268,  2.7451,  2.8088,  0.3649, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242,  1.2686,  2.8088,  1.9560, -0.3606, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.3097,  2.1851,  2.7324,  0.3140, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242,  1.1795,  2.8088,  1.8923, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
     0.5304,  2.7706,  2.6306,  0.3013, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.1824,
     2.3887,  2.8088,  1.6887, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.3860,  2.1596,
     2.8088,  2.3633,  0.0213, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.0595,  2.8088,
     2.8088,  0.5559, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.0296,  2.4269,  2.8088,
     1.0395, -0.4115, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242,  1.2686,  2.8088,  2.8088,
     0.2377, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242,  0.3522,  2.6560,  2.8088,  2.8088,
     0.2377, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242,  1.1159,  2.8088,  2.8088,  2.3633,
     0.0849, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242,  1.1159,  2.8088,  2.2105, -0.1951,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
    -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242
    

    15.编译源码

    $ make
    

    16.运行及其结果

    $ ./lenet
    predict: 7
    

    相关文章

      网友评论

          本文标题:c++实现lenet推理模型

          本文链接:https://www.haomeiwen.com/subject/mshqodtx.html