美文网首页
c++实现rnn推理模型

c++实现rnn推理模型

作者: 一路向后 | 来源:发表于2024-01-20 16:53 被阅读0次

    1.问题

    用RNN实现输入一个字母,预测出下一个字母:
    输入a, 预测出b
    输入b, 预测出c
    输入c, 预测出d
    输入d, 预测出e
    输入e, 预测出a

    2.tensor.h

    #ifndef _CONVNET_TENSOR_H_
    #define _CONVNET_TENSOR_H_
    
    #include <vector>
    
    typedef double Real;
    
    namespace convnet {
    
        // Container for Real values with up to three dimensions.
        // `dim` records the extent of each axis and `data` stores all elements
        // in a single flat vector.
        class Tensor {
        public:
            // Creates a tensor with no dimensions and no elements.
            Tensor();
            // Creates a 1-D tensor with `a` elements.
            Tensor(int a);
            // Creates a 2-D tensor with extents (a, b).
            Tensor(int a, int b);
            // Creates a 3-D tensor with extents (a, b, c).
            Tensor(int a, int b, int c);
            ~Tensor();
    
            // Re-shape the tensor. The 3-argument overload discards the existing
            // contents first; the 1- and 2-argument overloads resize the storage
            // in place.
            void resize(int a);
            void resize(int a, int b);
            void resize(int a, int b, int c);
    
            // In-place activation functions.
            void relu();
            void tanh();
            void sigmoid();
            // Writes the index of the largest element of a non-empty 1-D tensor
            // into `s`; `s` is left untouched for any other tensor.
            void argmax(int &s);
    
            // Replaces the stored elements with a copy of `data`.
            void set(std::vector<Real> &data);
            // Number of stored elements.
            int size();
    
        private:
            // Layer classes that read and write dim/data directly.
            // NOTE(review): Conv2d, MaxPool2d and Reshape are not part of this
            // listing; only Linear and Rnn are shown.
            friend class Linear;
            friend class Conv2d;
            friend class MaxPool2d;
            friend class Reshape;
            friend class Rnn;
    
            std::vector<int> dim;   // extent of each axis
            std::vector<Real> data; // flat element storage
        };
    
    }
    
    #endif
    

    3.tensor.cpp

    #include <math.h>
    #include <iostream>
    #include "tensor.h"
    
    using namespace std;
    
    // Default constructor: the tensor starts with no dimensions and no elements.
    convnet::Tensor::Tensor()
        : dim(), data()
    {
    }
    
    // Builds a 1-D tensor. The extent `a` is always recorded in dim[0];
    // element storage is allocated (zero-filled) only when `a` is positive.
    convnet::Tensor::Tensor(int a)
        : dim(1, a)
    {
        if(a > 0)
        {
            data.resize(a);
        }
    }
    
    // Builds a 2-D tensor with extents (a, b).
    // The extents are always recorded in dim; zero-filled storage for a*b
    // elements is allocated only when both extents are positive.
    convnet::Tensor::Tensor(int a, int b)
    {
        dim.resize(2);
    
        dim[0] = a;
        dim[1] = b;
    
        // Check every extent on its own: the sign of the product is not enough,
        // because two negative extents multiply to a positive value. The element
        // count is computed in size_t so large extents do not overflow int.
        if(a > 0 && b > 0)
        {
            data.resize(static_cast<std::size_t>(a) * static_cast<std::size_t>(b));
        }
    }
    
    // Builds a 3-D tensor with extents (a, b, c).
    // The extents are always recorded in dim; zero-filled storage for a*b*c
    // elements is allocated only when all three extents are positive.
    convnet::Tensor::Tensor(int a, int b, int c)
    {
        dim.resize(3);
    
        dim[0] = a;
        dim[1] = b;
        dim[2] = c;
    
        // Check every extent on its own: the product of two negative extents and
        // one positive extent is positive as well. The element count is computed
        // in size_t so large extents do not overflow int.
        if(a > 0 && b > 0 && c > 0)
        {
            data.resize(static_cast<std::size_t>(a) *
                        static_cast<std::size_t>(b) *
                        static_cast<std::size_t>(c));
        }
    }
    
    // Re-shapes the tensor to 1-D with `a` elements.
    // Existing leading elements are kept and new elements are zero-filled.
    // A non-positive `a` empties the storage so that size() always agrees
    // with the recorded extent.
    void convnet::Tensor::resize(int a)
    {
        dim.assign(1, a);
    
        // std::vector::resize is already a no-op when the size is unchanged,
        // so no explicit "size differs" test is needed.
        data.resize(a > 0 ? static_cast<std::size_t>(a) : 0);
    }
    
    // Re-shapes the tensor to 2-D with extents (a, b).
    // Existing leading elements are kept and new elements are zero-filled.
    // If either extent is non-positive the storage is emptied, so size()
    // always agrees with the recorded extents.
    void convnet::Tensor::resize(int a, int b)
    {
        dim.resize(2);
    
        dim[0] = a;
        dim[1] = b;
    
        // Test each extent separately (two negatives multiply to a positive)
        // and compute the element count in size_t to avoid int overflow.
        std::size_t count = 0;
    
        if(a > 0 && b > 0)
        {
            count = static_cast<std::size_t>(a) * static_cast<std::size_t>(b);
        }
    
        data.resize(count);
    }
    
    void convnet::Tensor::resize(int a, int b, int c)
    {
        if(dim.size() != 0)
        {
            dim.clear();
        }
    
        if(data.size() != 0)
        {
            data.clear();
        }
    
        dim.resize(3);
    
        dim[0] = a;
        dim[1] = b;
        dim[2] = c;
    
        if(a*b*c > 0)
        {
            data.resize(a*b*c);
        }
    }
    
    // Nothing to release by hand: dim and data are std::vector members and
    // free their own storage when the tensor is destroyed.
    convnet::Tensor::~Tensor()
    {
    }
    
    void convnet::Tensor::set(std::vector<Real> &data)
    {
        this->data = data;
    }
    
    int convnet::Tensor::size()
    {
        return data.size();
    }
    
    void convnet::Tensor::relu()
    {
        for(int i=0; i<data.size(); i++)
        {
            if(data[i] < 0)
            {
                data[i] = 0;
            }
        }
    }
    
    void convnet::Tensor::sigmoid()
    {
        for(int i=0; i<data.size(); i++)
        {
            data[i] = 1 / (1+expf(-data[i]));
        }
    }
    
    // Stores in `s` the index of the largest element of a non-empty 1-D tensor.
    // When several elements share the maximum, the last of them wins (the
    // comparison is >=). For any other tensor `s` is left unchanged.
    void convnet::Tensor::argmax(int &s)
    {
        if(dim.size() != 1 || dim[0] <= 0)
        {
            return;
        }
    
        int best = 0;
    
        for(int k = 1; k < dim[0]; ++k)
        {
            if(data[k] >= data[best])
            {
                best = k;
            }
        }
    
        s = best;
    }
    
    void convnet::Tensor::tanh()
    {
        if(dim.size() == 1 && dim[0] > 0)
        {
            int i = 1;
    
            for(i=0; i<dim[0]; i++)
            {
                data[i] = std::tanh(data[i]);
            }
        }
    }
    

    4.rnn.h

    #ifndef _CONVNET_RNN_H_
    #define _CONVNET_RNN_H_
    
    #include <string>
    #include "tensor.h"
    
    namespace convnet {
    
        // Single recurrent cell. forward() computes
        //   st = tanh(sw * input + sb + ow * ht + ob)
        // and stores the result both in the internal state `ht` and in the
        // tensor registered as `hidden`.
        class Rnn {
        public:
            // m: number of input elements read per step; n: size of the hidden state.
            Rnn(int m, int n);
            ~Rnn();
    
            // Registers the input tensor `i` and the output (hidden state) tensor `h`.
            // Both are borrowed, not owned, and must be registered before
            // forward() or print() is called.
            void set(Tensor *i, Tensor *h);
            // Loads the parameters: w1 -> sw (m*n values), w2 -> ow (n*n values),
            // b1 -> sb (n values), b2 -> ob (n values).
            void setargs(std::vector<Real> &w1, std::vector<Real> &w2, std::vector<Real> &b1, std::vector<Real> &b2);
            // Runs one time step and updates the internal hidden state.
            void forward();
            // type 0 prints the input values, type 1 prints the hidden values.
            void print(int type);
    
        private:
            int m;          // input size
            int n;          // hidden size
            Tensor sw;      // input-to-hidden weights, read as sw[m*i + j]
            Tensor sb;      // input-to-hidden bias
            Tensor ow;      // hidden-to-hidden weights, read as ow[n*i + j]
            Tensor ob;      // hidden-to-hidden bias
            Tensor ht;      // hidden state carried over from the previous step
            Tensor st;      // scratch buffer for the new state
            Tensor *input;  // borrowed input tensor
            Tensor *hidden; // borrowed tensor that receives the new hidden state
        };
    
    }
    
    #endif
    

    5.rnn.cpp

    #include <cassert>
    #include <cstdio>
    #include "rnn.h"
    
    // Creates a recurrent cell that reads `m` input values and keeps `n`
    // hidden values. All parameter tensors are sized here; their values are
    // loaded later through setargs().
    convnet::Rnn::Rnn(int m, int n)
        : m(m), n(n), input(nullptr), hidden(nullptr) // no tensors registered until set()
    {
        sw.resize(m, n);
        sb.resize(n);
        ow.resize(n, n);
        ob.resize(n);
        st.resize(n);
        ht.resize(n);
    
        // The initial hidden state is all zeros.
        for(int s=0; s<n; s++)
        {
            ht.data[s] = 0.0;
        }
    }
    
    // The cell does not own `input` or `hidden`, so there is nothing to free.
    convnet::Rnn::~Rnn() {}
    
    // Registers the tensors used by forward(): `i` supplies the m input values
    // and `h` receives the n hidden values. Neither tensor is owned by the cell.
    // Both tensors are resized when their element count does not match, in the
    // same way Linear::set() does, so that forward() never indexes past the end
    // of a tensor that was created too small.
    void convnet::Rnn::set(Tensor *i, Tensor *h)
    {
        assert(i != nullptr && h != nullptr);
    
        if(m != i->size())
        {
            i->resize(m);
        }
    
        if(n != 0 && h->size() != n)
        {
            h->resize(n);
        }
    
        this->input = i;
        this->hidden = h;
    }
    
    void convnet::Rnn::forward()
    {
        int i, j;
    
        for(i=0; i<n; i++)
        {
            st.data[i] = 0.0;
    
            for(j=0; j<m; j++)
            {
                st.data[i] += input->data[j] * sw.data[m*i + j];
            }
    
            st.data[i] += sb.data[i];
    
            for(j=0; j<n; j++)
            {
                st.data[i] += ht.data[j] * ow.data[n*i + j];
            }
    
            st.data[i] += ob.data[i];
        }
    
        st.tanh();
    
        for(i=0; i<n; i++)
        {
            ht.data[i] = st.data[i];
            hidden->data[i] = st.data[i];
        }
    }
    
    // Loads the cell parameters.
    //   w1 / b1: input-to-hidden weights (m*n values) and bias (n values)
    //   w2 / b2: hidden-to-hidden weights (n*n values) and bias (n values)
    // The sizes are checked with assert only.
    void convnet::Rnn::setargs(std::vector<Real> &w1, std::vector<Real> &w2, std::vector<Real> &b1, std::vector<Real> &b2)
    {
        assert(w1.size() == m * n && b1.size() == n);
        assert(w2.size() == n * n && b2.size() == n);
    
        // Input-to-hidden parameters.
        sw.set(w1);
        sb.set(b1);
    
        // Hidden-to-hidden parameters.
        ow.set(w2);
        ob.set(b2);
    }
    
    void convnet::Rnn::print(int type)
    {
        if(type == 0)
        {
            for(int i=0; i<m; i++)
            {
                printf("%.6lf ", input->data[i]);
            }
    
            printf("\n");
        }
        else if(type == 1)
        {
            for(int i=0; i<n; i++)
            {
                printf("%.6lf ", hidden->data[i]);
            }
    
            printf("\n");
        }
    }
    

    6.linear.h

    #ifndef _CONVNET_LINEAR_H_
    #define _CONVNET_LINEAR_H_
    
    #include "tensor.h"
    
    namespace convnet {
    
        // Fully connected layer: output[o] = sum_i input[i] * weight[m*o + i] + bias[o].
        class Linear {
        public:
            // m: number of input values; n: number of output values.
            Linear(int m, int n);
    
            // Registers the input tensor `i` and output tensor `o` (borrowed, not
            // owned); each is resized when its element count does not match.
            void set(Tensor *i, Tensor *o);
            // Computes the output from the registered input.
            void forward();
            // Loads the weights `w` (m*n values) and the bias `b` (n values).
            void setargs(std::vector<Real> &w, std::vector<Real> &b);
            // type 0 prints the input values, type 1 prints the output values.
            void print(int type);
    
        private:
            int m;          // input size
            int n;          // output size
            Tensor weight;  // weights, read as weight[m*out + in]
            Tensor bias;    // one bias value per output
            Tensor *input;  // borrowed input tensor
            Tensor *output; // borrowed output tensor
        };
    
    }
    
    #endif
    

    7.linear.cpp

    #include <stdio.h>
    #include <cassert>
    #include "linear.h"
    
    // Creates a fully connected layer with `m` inputs and `n` outputs.
    // The parameter tensors are sized here; their values are loaded later
    // through setargs().
    convnet::Linear::Linear(int m, int n)
        : m(m), n(n), input(nullptr), output(nullptr) // no tensors registered until set()
    {
        weight.resize(m, n);
        bias.resize(n);
    }
    
    // Registers the tensors used by forward(). The input tensor is resized to
    // m elements and the output tensor to n elements whenever their current
    // element counts differ. Neither tensor is owned by the layer.
    void convnet::Linear::set(Tensor *i, Tensor *o)
    {
        input = i;
        output = o;
    
        if(input->size() != m)
        {
            input->resize(m);
        }
    
        if(output->size() != n)
        {
            output->resize(n);
        }
    }
    
    // Loads the layer parameters: `w` holds m*n weights and `b` holds n bias
    // values. The sizes are checked with assert only.
    void convnet::Linear::setargs(std::vector<Real> &w, std::vector<Real> &b)
    {
        assert(w.size() == m * n && b.size() == n);
    
        weight.data = w;
        bias.data = b;
    }
    
    void convnet::Linear::forward()
    {
        for(int out=0; out<n; out++)
        {
            output->data[out] = 0.0;
    
            for(int in=0; in<m; in++)
            {
                output->data[out] += input->data[in] * weight.data[m*out + in];
            }
    
            output->data[out] += bias.data[out];
        }
    }
    
    // Prints one line of values with six decimals each, separated by spaces.
    //   type == 0: the m input values
    //   type == 1: the n output values
    // Any other type prints nothing.
    void convnet::Linear::print(int type)
    {
        switch(type)
        {
        case 0:
            for(int k = 0; k < m; ++k)
            {
                printf("%.6lf ", input->data[k]);
            }
    
            printf("\n");
            break;
    
        case 1:
            for(int k = 0; k < n; ++k)
            {
                printf("%.6lf ", output->data[k]);
            }
    
            printf("\n");
            break;
    
        default:
            break;
        }
    }
    

    8.main.cpp

    #include <iostream>
    #include <vector>
    #include "tensor.h"
    #include "rnn.h"
    #include "linear.h"
    
    using namespace std;
    using namespace convnet;
    
    int main()
    {
        Tensor *input = new Tensor(5);
        Tensor output[2];
        Rnn rnn(5, 10);
        Linear linear(10, 5);
        vector<Real> w1 = {
            1.0875e-01,  1.9007e-01,  3.3171e-02, -4.2167e-01,  2.3260e-01,
            -4.0091e-02,  5.0844e-01, -4.9908e-03, -4.6057e-01, -3.2492e-01,
            -4.4841e-01, -1.9543e-01,  4.1142e-01,  3.3704e-01, -7.6349e-02,
            2.7550e-01,  1.1706e-01,  5.2413e-01,  3.6117e-01, -5.5991e-01,
            3.9840e-01,  2.3983e-02, -3.0162e-01, -6.0204e-02, -1.3524e-01,
            3.9064e-04, -4.0791e-01,  2.9194e-01, -3.2485e-01,  1.1633e-01,
            2.1486e-01,  1.3768e-01,  3.2932e-01, -2.1038e-01,  3.9599e-01,
            -2.4395e-01, -4.5648e-02,  3.3629e-01, -4.9821e-01, -9.9775e-02,
            4.2604e-01, -6.1709e-01,  1.4171e-02, -3.5881e-01, -3.0136e-01,
            -1.1071e-01,  4.3087e-01, -7.1492e-02, -4.4776e-02, -1.9427e-01
        };
        vector<Real> w2 = {
             0.1804, -0.0511, -0.1194, -0.3126,  0.3056,  0.2208,  0.2536, -0.2775, 0.1334, -0.0084,
            -0.0078,  0.0143,  0.0403,  0.1966, -0.0028,  0.0869,  0.0081, -0.2408, 0.0628,  0.0728,
            0.0433, -0.2915,  0.2838, -0.1858, -0.0760, -0.2338,  0.0192,  0.2064, -0.0470, -0.2736,
            -0.1543, -0.0061, -0.0271,  0.1564, -0.1332,  0.2041, -0.0063, -0.0483, 0.3013, -0.0242,
            -0.0377,  0.1239, -0.1080,  0.2230,  0.1908, -0.2534, -0.2355, -0.2026, -0.0397, -0.0283,
            0.3074, -0.1016, -0.2998, -0.2427, -0.0007, -0.1828, -0.0867, -0.2579, 0.2764, -0.1827,
            -0.0062,  0.0415, -0.1900,  0.1646, -0.0817,  0.1933, -0.1867, -0.0074, -0.3107, -0.2211,
            0.0158,  0.3108, -0.0322,  0.0481,  0.2690,  0.1093, -0.2631,  0.2370, -0.1548, -0.2132,
            -0.2503,  0.2321, -0.0190, -0.2398,  0.1281, -0.2103,  0.3047,  0.3008, -0.2617, -0.1564,
            0.0903, -0.2276,  0.1263,  0.0693, -0.2775,  0.2864,  0.1292, -0.3017, 0.1994, -0.1917
        };
        vector<Real> w3 = {
            0.2785, -0.4658, -0.1571, -0.7094, -0.3439,  0.1966,  0.2905, -0.1365, -0.1963, -0.0616,
            0.0765,  0.0545, -0.4994,  0.2200,  0.4290,  0.0951, -0.0870, -0.2526, 0.5923, -0.1664,
            0.3581,  0.4772, -0.1968,  0.1624, -0.0177, -0.2860,  0.0372,  0.0018, -0.4826,  0.4231,
            0.2152,  0.0899,  0.4478,  0.4280, -0.4368,  0.3140,  0.1704,  0.4542, 0.2319,  0.0220,
            -0.2668, -0.4524,  0.2827,  0.3360, -0.0086, -0.2787, -0.1835, -0.5607, -0.2900, -0.1088
        };
        vector<Real> b1 = {0.2237, -0.1793,  0.1222, -0.0701, -0.1727,  0.0800,  0.2516, -0.3360, -0.0056,  0.0087};
        vector<Real> b2 = {0.1713, -0.0580, -0.1222,  0.1427,  0.0238, -0.0174,  0.0815,  0.2246, 0.0837,  0.0762};
        vector<Real> b3 = {-0.1141,  0.2620,  0.0653,  0.0189,  0.0603};
        vector<Real> x1 = {1,0,0,0,0};
        char chx1 = 0x00;
        int y1;
    
        printf("请输入测试字母: ");
    
        scanf("%c", &chx1);
    
        if(chx1 == 'a')
        {
            x1 = {1, 0, 0, 0, 0};
        }
        else if(chx1 == 'b')
        {
            x1 = {0, 1, 0, 0, 0};
        }
        else if(chx1 == 'c')
        {
            x1 = {0, 0, 1, 0, 0};
        }
        else if(chx1 == 'd')
        {
            x1 = {0, 0, 0, 1, 0};
        }
        else if(chx1 == 'e')
        {
            x1 = {0, 0, 0, 0, 1};
        }
        else
        {
            delete input;
    
            return -1;
        }
    
        rnn.set(input, &output[0]);
        linear.set(&output[0], &output[1]);
    
        rnn.setargs(w1, w2, b1, b2);
        linear.setargs(w3, b3);
    
        input->set(x1);
        rnn.forward();
        linear.forward();
    
        output[1].argmax(y1);
    
        printf("预测字母为: %c\n", y1+'a');
    
        delete input;
    
        return 0;
    }
    

    9.Makefile

    # Build script for the rnn demo.
    # NOTE: when this listing is copied into a real Makefile, every recipe line
    # (the indented command lines) must start with a TAB character.
    CXX=g++
    STD=-std=c++11
    DEBUG=-g
    LDFLAGS=
    # Extra compiler flags; spelled CXXFLAGS so the rules below actually use it.
    CXXFLAGS=
    OBJS=tensor.o rnn.o linear.o
    
    # clean does not produce a file named "clean".
    .PHONY: clean
    
    # main.cpp includes all three headers, so it is rebuilt when any of them changes.
    rnn: main.cpp tensor.h rnn.h linear.h $(OBJS)
        $(CXX) $(DEBUG) -o rnn main.cpp $(OBJS) $(STD) $(CXXFLAGS) $(LDFLAGS)
    
    tensor.o: tensor.cpp tensor.h
        $(CXX) $(DEBUG) -c tensor.cpp $(STD) $(CXXFLAGS)
    
    # rnn.h and linear.h both include tensor.h.
    rnn.o: rnn.cpp rnn.h tensor.h
        $(CXX) $(DEBUG) -c rnn.cpp $(STD) $(CXXFLAGS)
    
    linear.o: linear.cpp linear.h tensor.h
        $(CXX) $(DEBUG) -c linear.cpp $(STD) $(CXXFLAGS)
    
    clean:
        rm -rf rnn
        rm -rf $(OBJS)
    

    10.编译源码

    $ make
    

    11.运行及其结果

    $ ./rnn 
    请输入测试字母: a
    预测字母为: b
    $ ./rnn 
    请输入测试字母: b
    预测字母为: c
    $ ./rnn 
    请输入测试字母: c
    预测字母为: d
    $ ./rnn 
    请输入测试字母: d
    预测字母为: e
    $ ./rnn 
    请输入测试字母: e
    预测字母为: a
    

    相关文章

      网友评论

          本文标题:c++实现rnn推理模型

          本文链接:https://www.haomeiwen.com/subject/tjdsodtx.html