1.问题
用RNN实现:输入一个字母,预测出它的下一个字母(字母表为 a~e,循环排列,e 的下一个是 a):
输入a, 预测出b
输入b, 预测出c
输入c, 预测出d
输入d, 预测出e
输入e, 预测出a
2.tensor.h
#ifndef _CONVNET_TENSOR_H_
#define _CONVNET_TENSOR_H_
#include <vector>
typedef double Real;
namespace convnet {
// Dense container of Real values with an explicit shape of up to three
// dimensions. The shape is kept in `dim` and the elements in the flat vector
// `data`; the layer classes listed as friends read and write both directly.
class Tensor {
public:
// Creates a tensor with no dimensions and no elements.
Tensor();
// Create a rank-1 / rank-2 / rank-3 tensor. Storage is allocated only when
// the product of the extents is positive.
Tensor(int a);
Tensor(int a, int b);
Tensor(int a, int b, int c);
~Tensor();
// Change the shape to rank 1 / 2 / 3. Storage is resized only when the
// product of the extents is positive.
void resize(int a);
void resize(int a, int b);
void resize(int a, int b, int c);
// In-place activation functions applied to the stored elements.
// NOTE(review): check the definitions in tensor.cpp for rank restrictions.
void relu();
void tanh();
void sigmoid();
// For a non-empty rank-1 tensor, writes the index of the largest element
// into `s`; otherwise `s` is left unchanged.
void argmax(int &s);
// Replaces the stored elements with a copy of `data`. The recorded shape is
// not updated.
void set(std::vector<Real> &data);
// Number of stored elements.
int size();
private:
// Layer classes that access `dim` and `data` directly.
friend class Linear;
friend class Conv2d;
friend class MaxPool2d;
friend class Reshape;
friend class Rnn;
std::vector<int> dim;   // extent of each dimension
std::vector<Real> data; // flat element storage
};
}
#endif
3.tensor.cpp
#include <math.h>
#include <cmath>
#include <iostream>
#include "tensor.h"
using namespace std;
// Default constructor: the tensor starts with no dimensions and no elements.
convnet::Tensor::Tensor() : dim(), data()
{
}
// Builds a rank-1 tensor with extent `a`.
// Element storage (zero-initialised) is allocated only when `a` is positive;
// otherwise the extent is recorded but the tensor holds no elements.
convnet::Tensor::Tensor(int a) : dim(1, a)
{
    if (a <= 0)
    {
        return;
    }
    data.resize(a);
}
// Builds a rank-2 tensor with extents `a` x `b`.
// Element storage (zero-initialised) is allocated only when a*b is positive.
convnet::Tensor::Tensor(int a, int b)
{
    dim = {a, b};

    const int count = a * b;
    if (count > 0)
    {
        data.resize(count);
    }
}
// Builds a rank-3 tensor with extents `a` x `b` x `c`.
// Element storage (zero-initialised) is allocated only when a*b*c is positive.
convnet::Tensor::Tensor(int a, int b, int c)
{
    dim = {a, b, c};

    const int count = a * b * c;
    if (count > 0)
    {
        data.resize(count);
    }
}
// Reshapes the tensor to rank 1 with extent `a`.
// Existing elements are kept (vector::resize semantics). When `a` is not
// positive only the recorded shape changes and the storage is left untouched.
void convnet::Tensor::resize(int a)
{
    dim.assign(1, a);

    if (a > 0)
    {
        // vector::resize does nothing when the size already matches, so no
        // explicit comparison against the current size is needed.
        data.resize(a);
    }
}
// Reshapes the tensor to rank 2 with extents `a` x `b`.
// Existing elements are kept (vector::resize semantics). When a*b is not
// positive only the recorded shape changes and the storage is left untouched.
void convnet::Tensor::resize(int a, int b)
{
    dim = {a, b};

    const int count = a * b;
    if (count > 0)
    {
        // vector::resize does nothing when the size already matches.
        data.resize(count);
    }
}
// Reshapes the tensor to rank 3 with extents `a` x `b` x `c`.
// The previous contents are always discarded: the storage is emptied first
// and then, if a*b*c is positive, refilled with that many zero elements.
void convnet::Tensor::resize(int a, int b, int c)
{
    dim = {a, b, c};
    data.clear();

    const int count = a * b * c;
    if (count > 0)
    {
        data.resize(count);
    }
}
// Destructor. `dim` and `data` are std::vector members and release their own
// storage when the tensor is destroyed, so no explicit cleanup is required.
convnet::Tensor::~Tensor()
{
}
// Replaces the stored elements with a copy of `data`.
// The recorded shape (`dim`) is not touched, so the caller is responsible for
// passing a vector whose length matches the tensor's shape.
void convnet::Tensor::set(std::vector<Real> &data)
{
    // The parameter hides the member of the same name; bind the member
    // explicitly before copying into it.
    std::vector<Real> &storage = this->data;
    storage = data;
}
// Returns the number of elements currently held in the flat storage
// (not the product of the recorded extents).
int convnet::Tensor::size()
{
    const int count = static_cast<int>(data.size());
    return count;
}
void convnet::Tensor::relu()
{
for(int i=0; i<data.size(); i++)
{
if(data[i] < 0)
{
data[i] = 0;
}
}
}
void convnet::Tensor::sigmoid()
{
for(int i=0; i<data.size(); i++)
{
data[i] = 1 / (1+expf(-data[i]));
}
}
// Writes the index of the largest element into `s`.
// Only non-empty rank-1 tensors are handled; for anything else `s` is left
// unchanged. When several elements share the maximum, the highest index wins
// because later elements replace the current best on equality.
void convnet::Tensor::argmax(int &s)
{
    const bool isVector = (dim.size() == 1);
    if (!isVector || dim[0] <= 0)
    {
        return;
    }

    const int length = dim[0];
    int best = 0;
    for (int k = 1; k < length; ++k)
    {
        if (data[k] >= data[best])
        {
            best = k;
        }
    }
    s = best;
}
void convnet::Tensor::tanh()
{
if(dim.size() == 1 && dim[0] > 0)
{
int i = 1;
for(i=0; i<dim[0]; i++)
{
data[i] = std::tanh(data[i]);
}
}
}
4.rnn.h
#ifndef _CONVNET_RNN_H_
#define _CONVNET_RNN_H_
#include <string>
#include "tensor.h"
namespace convnet {
// Simple recurrent layer with `m` inputs and `n` hidden units.
//
// Each forward() step computes, for every hidden unit,
//   tanh(sw * input + sb + ow * previous_hidden + ob)
// and stores the result both in the internal state and in the tensor passed
// as `h` to set(). The internal state starts at zero.
class Rnn {
public:
    // m: number of input values, n: number of hidden units.
    Rnn(int m, int n);
    ~Rnn();

    // Attaches the input tensor `i` and the tensor `h` that receives the
    // hidden state. The layer does not take ownership; both must outlive
    // every later forward()/print() call.
    void set(Tensor *i, Tensor *h);

    // Loads the parameters: w1 = input weights (m*n values), w2 = recurrent
    // weights (n*n values), b1/b2 = the two bias vectors (n values each).
    void setargs(std::vector<Real> &w1, std::vector<Real> &w2, std::vector<Real> &b1, std::vector<Real> &b2);

    // Runs one time step on the attached tensors.
    void forward();

    // Prints the attached input (type == 0) or hidden tensor (type == 1).
    void print(int type);

private:
    int m;       // number of inputs
    int n;       // number of hidden units
    Tensor sw;   // input weights, read as n rows of m values
    Tensor sb;   // input bias
    Tensor ow;   // recurrent weights, read as n rows of n values
    Tensor ob;   // recurrent bias
    Tensor ht;   // hidden state carried over from the previous step
    Tensor st;   // scratch buffer for the state being computed
    // Not owned. Null until set() is called, so that using the layer before
    // attaching tensors fails deterministically instead of dereferencing an
    // indeterminate pointer.
    Tensor *input = nullptr;
    Tensor *hidden = nullptr;
};
}
#endif
5.rnn.cpp
#include <cassert>
#include "rnn.h"
// Creates a recurrent layer with `m` inputs and `n` hidden units.
// All parameter tensors are sized here and the carried hidden state is
// explicitly reset to zero.
convnet::Rnn::Rnn(int m, int n) : m(m), n(n)
{
    // Parameter storage.
    sw.resize(m, n);
    sb.resize(n);
    ow.resize(n, n);
    ob.resize(n);

    // Working buffers.
    st.resize(n);
    ht.resize(n);

    // Start from an all-zero hidden state.
    for (Real &value : ht.data)
    {
        value = 0.0;
    }
}
// Nothing to release: the attached input/hidden tensors are not owned and the
// member tensors clean up after themselves.
convnet::Rnn::~Rnn() = default;
// Attaches the tensors used by forward() and print().
//
// i: input tensor; forward() reads its first `m` elements.
// h: tensor that receives the hidden state; resized to `n` elements when its
//    size differs (and n is non-zero).
//
// Neither tensor is owned by the layer; both must stay alive for as long as
// the layer is used.
void convnet::Rnn::set(Tensor *i, Tensor *h)
{
    input = i;
    hidden = h;

    // forward() indexes input->data[0 .. m-1]. Grow an input that is too
    // small so those reads stay inside its storage; larger inputs are left
    // alone, so calls that were already valid behave exactly as before.
    if (i->size() < m)
    {
        i->resize(m);
    }

    if (n != 0 && h->size() != n)
    {
        h->resize(n);
    }
}
void convnet::Rnn::forward()
{
int i, j;
for(i=0; i<n; i++)
{
st.data[i] = 0.0;
for(j=0; j<m; j++)
{
st.data[i] += input->data[j] * sw.data[m*i + j];
}
st.data[i] += sb.data[i];
for(j=0; j<n; j++)
{
st.data[i] += ht.data[j] * ow.data[n*i + j];
}
st.data[i] += ob.data[i];
}
st.tanh();
for(i=0; i<n; i++)
{
ht.data[i] = st.data[i];
hidden->data[i] = st.data[i];
}
}
// Loads the layer parameters.
//
// w1: input weights, m*n values.   b1: input bias, n values.
// w2: recurrent weights, n*n values. b2: recurrent bias, n values.
//
// The sizes are checked with assert only, i.e. not at all when NDEBUG is
// defined.
void convnet::Rnn::setargs(std::vector<Real> &w1, std::vector<Real> &w2, std::vector<Real> &b1, std::vector<Real> &b2)
{
    // Input-side parameters.
    assert(w1.size() == m * n);
    assert(b1.size() == n);

    // Recurrent-side parameters.
    assert(w2.size() == n * n);
    assert(b2.size() == n);

    sw.set(w1);
    ow.set(w2);
    sb.set(b1);
    ob.set(b2);
}
void convnet::Rnn::print(int type)
{
if(type == 0)
{
for(int i=0; i<m; i++)
{
printf("%.6lf ", input->data[i]);
}
printf("\n");
}
else if(type == 1)
{
for(int i=0; i<n; i++)
{
printf("%.6lf ", hidden->data[i]);
}
printf("\n");
}
}
6.linear.h
#ifndef _CONVNET_LINEAR_H_
#define _CONVNET_LINEAR_H_
#include "tensor.h"
namespace convnet {
// Fully connected layer mapping `m` input values to `n` output values:
//   output = weight * input + bias
class Linear {
public:
    // m: number of input values, n: number of output values.
    Linear(int m, int n);

    // Attaches the input tensor `i` and the output tensor `o`, resizing each
    // to the expected length when it differs. The layer does not take
    // ownership; both must outlive every later forward()/print() call.
    void set(Tensor *i, Tensor *o);

    // Computes the output from the attached input.
    void forward();

    // Loads the parameters: w = weights (m*n values), b = bias (n values).
    void setargs(std::vector<Real> &w, std::vector<Real> &b);

    // Prints the attached input (type == 0) or output tensor (type == 1).
    void print(int type);

private:
    int m;          // number of inputs
    int n;          // number of outputs
    Tensor weight;  // read as n rows of m values
    Tensor bias;    // one value per output
    // Not owned. Null until set() is called, so that using the layer before
    // attaching tensors fails deterministically instead of dereferencing an
    // indeterminate pointer.
    Tensor *input = nullptr;
    Tensor *output = nullptr;
};
}
#endif
7.linear.cpp
#include <stdio.h>
#include <cassert>
#include "linear.h"
// Creates a fully connected layer with `m` inputs and `n` outputs and sizes
// its weight and bias tensors accordingly.
convnet::Linear::Linear(int m, int n) : m(m), n(n)
{
    weight.resize(m, n);
    bias.resize(n);
}
// Attaches the tensors used by forward() and print().
//
// i: input tensor, resized to `m` elements when its size differs.
// o: output tensor, resized to `n` elements when its size differs.
//
// Neither tensor is owned by the layer.
void convnet::Linear::set(Tensor *i, Tensor *o)
{
    input = i;
    output = o;

    if (input->size() != m)
    {
        input->resize(m);
    }

    if (output->size() != n)
    {
        output->resize(n);
    }
}
// Loads the layer parameters.
//
// w: weights, m*n values.  b: bias, n values.
//
// The sizes are checked with assert only, i.e. not at all when NDEBUG is
// defined.
void convnet::Linear::setargs(std::vector<Real> &w, std::vector<Real> &b)
{
    assert(w.size() == m * n && b.size() == n);

    // Tensor::set copies the vector into the tensor's storage.
    weight.set(w);
    bias.set(b);
}
void convnet::Linear::forward()
{
for(int out=0; out<n; out++)
{
output->data[out] = 0.0;
for(int in=0; in<m; in++)
{
output->data[out] += input->data[in] * weight.data[m*out + in];
}
output->data[out] += bias.data[out];
}
}
void convnet::Linear::print(int type)
{
if(type == 0)
{
for(int i=0; i<m; i++)
{
printf("%.6lf ", input->data[i]);
}
printf("\n");
}
else if(type == 1)
{
for(int i=0; i<n; i++)
{
printf("%.6lf ", output->data[i]);
}
printf("\n");
}
}
8.main.cpp
#include <iostream>
#include <vector>
#include "tensor.h"
#include "rnn.h"
#include "linear.h"
using namespace std;
using namespace convnet;
int main()
{
Tensor *input = new Tensor(5);
Tensor output[2];
Rnn rnn(5, 10);
Linear linear(10, 5);
vector<Real> w1 = {
1.0875e-01, 1.9007e-01, 3.3171e-02, -4.2167e-01, 2.3260e-01,
-4.0091e-02, 5.0844e-01, -4.9908e-03, -4.6057e-01, -3.2492e-01,
-4.4841e-01, -1.9543e-01, 4.1142e-01, 3.3704e-01, -7.6349e-02,
2.7550e-01, 1.1706e-01, 5.2413e-01, 3.6117e-01, -5.5991e-01,
3.9840e-01, 2.3983e-02, -3.0162e-01, -6.0204e-02, -1.3524e-01,
3.9064e-04, -4.0791e-01, 2.9194e-01, -3.2485e-01, 1.1633e-01,
2.1486e-01, 1.3768e-01, 3.2932e-01, -2.1038e-01, 3.9599e-01,
-2.4395e-01, -4.5648e-02, 3.3629e-01, -4.9821e-01, -9.9775e-02,
4.2604e-01, -6.1709e-01, 1.4171e-02, -3.5881e-01, -3.0136e-01,
-1.1071e-01, 4.3087e-01, -7.1492e-02, -4.4776e-02, -1.9427e-01
};
vector<Real> w2 = {
0.1804, -0.0511, -0.1194, -0.3126, 0.3056, 0.2208, 0.2536, -0.2775, 0.1334, -0.0084,
-0.0078, 0.0143, 0.0403, 0.1966, -0.0028, 0.0869, 0.0081, -0.2408, 0.0628, 0.0728,
0.0433, -0.2915, 0.2838, -0.1858, -0.0760, -0.2338, 0.0192, 0.2064, -0.0470, -0.2736,
-0.1543, -0.0061, -0.0271, 0.1564, -0.1332, 0.2041, -0.0063, -0.0483, 0.3013, -0.0242,
-0.0377, 0.1239, -0.1080, 0.2230, 0.1908, -0.2534, -0.2355, -0.2026, -0.0397, -0.0283,
0.3074, -0.1016, -0.2998, -0.2427, -0.0007, -0.1828, -0.0867, -0.2579, 0.2764, -0.1827,
-0.0062, 0.0415, -0.1900, 0.1646, -0.0817, 0.1933, -0.1867, -0.0074, -0.3107, -0.2211,
0.0158, 0.3108, -0.0322, 0.0481, 0.2690, 0.1093, -0.2631, 0.2370, -0.1548, -0.2132,
-0.2503, 0.2321, -0.0190, -0.2398, 0.1281, -0.2103, 0.3047, 0.3008, -0.2617, -0.1564,
0.0903, -0.2276, 0.1263, 0.0693, -0.2775, 0.2864, 0.1292, -0.3017, 0.1994, -0.1917
};
vector<Real> w3 = {
0.2785, -0.4658, -0.1571, -0.7094, -0.3439, 0.1966, 0.2905, -0.1365, -0.1963, -0.0616,
0.0765, 0.0545, -0.4994, 0.2200, 0.4290, 0.0951, -0.0870, -0.2526, 0.5923, -0.1664,
0.3581, 0.4772, -0.1968, 0.1624, -0.0177, -0.2860, 0.0372, 0.0018, -0.4826, 0.4231,
0.2152, 0.0899, 0.4478, 0.4280, -0.4368, 0.3140, 0.1704, 0.4542, 0.2319, 0.0220,
-0.2668, -0.4524, 0.2827, 0.3360, -0.0086, -0.2787, -0.1835, -0.5607, -0.2900, -0.1088
};
vector<Real> b1 = {0.2237, -0.1793, 0.1222, -0.0701, -0.1727, 0.0800, 0.2516, -0.3360, -0.0056, 0.0087};
vector<Real> b2 = {0.1713, -0.0580, -0.1222, 0.1427, 0.0238, -0.0174, 0.0815, 0.2246, 0.0837, 0.0762};
vector<Real> b3 = {-0.1141, 0.2620, 0.0653, 0.0189, 0.0603};
vector<Real> x1 = {1,0,0,0,0};
char chx1 = 0x00;
int y1;
printf("请输入测试字母: ");
scanf("%c", &chx1);
if(chx1 == 'a')
{
x1 = {1, 0, 0, 0, 0};
}
else if(chx1 == 'b')
{
x1 = {0, 1, 0, 0, 0};
}
else if(chx1 == 'c')
{
x1 = {0, 0, 1, 0, 0};
}
else if(chx1 == 'd')
{
x1 = {0, 0, 0, 1, 0};
}
else if(chx1 == 'e')
{
x1 = {0, 0, 0, 0, 1};
}
else
{
delete input;
return -1;
}
rnn.set(input, &output[0]);
linear.set(&output[0], &output[1]);
rnn.setargs(w1, w2, b1, b2);
linear.setargs(w3, b3);
input->set(x1);
rnn.forward();
linear.forward();
output[1].argmax(y1);
printf("预测字母为: %c\n", y1+'a');
delete input;
return 0;
}
9.Makefile
# Build configuration for the `rnn` demo executable.
CXX=g++
STD=-std=c++11
DEBUG=-g
LDFLAGS=
# Extra compiler flags. The variable was previously defined as CXXFLASG and
# referenced as CXXLFAGS in two rules, so flags set here (or passed on the
# make command line) never reached the compiler.
CXXFLAGS=
OBJS=tensor.o rnn.o linear.o

.PHONY: clean

# Final executable; main.cpp is compiled and linked in one step.
rnn: main.cpp tensor.h rnn.h linear.h $(OBJS)
	$(CXX) $(DEBUG) -o rnn main.cpp $(OBJS) $(STD) $(CXXFLAGS) $(LDFLAGS)

tensor.o: tensor.cpp tensor.h
	$(CXX) $(DEBUG) -c tensor.cpp $(STD) $(CXXFLAGS)

# rnn.h and linear.h both include tensor.h, so it is a prerequisite as well.
rnn.o: rnn.cpp rnn.h tensor.h
	$(CXX) $(DEBUG) -c rnn.cpp $(STD) $(CXXFLAGS)

linear.o: linear.cpp linear.h tensor.h
	$(CXX) $(DEBUG) -c linear.cpp $(STD) $(CXXFLAGS)

clean:
	rm -rf rnn
	rm -rf $(OBJS)
10.编译源码
$ make
11.运行及其结果
$ ./rnn
请输入测试字母: a
预测字母为: b
$ ./rnn
请输入测试字母: b
预测字母为: c
$ ./rnn
请输入测试字母: c
预测字母为: d
$ ./rnn
请输入测试字母: d
预测字母为: e
$ ./rnn
请输入测试字母: e
预测字母为: a
网友评论