DeepLearnToolBox是matlab下的一个简单的深度学习工具包,接口简单易用,其代码是纯matlab编写。
使用过程非常简单,总共分两步:
- 在github上下载代码;
- 打开matlab,在matlab命令行窗口中输入:addpath(genpath('所在文件夹\DeepLearnToolbox'));
然后就可以愉快地敲代码了,下面是一个用于识别MNIST手写数字的官方示例:
function test_example_CNN
load mnist_uint8; % 加载手写数字
% 处理数据
train_x = double(reshape(train_x',28,28,60000))/255;
test_x = double(reshape(test_x',28,28,10000))/255;
train_y = double(train_y');
test_y = double(test_y');
%% 建立一个卷积神经网络
% 跑一次循环需要200秒,一个epoch可以获得11%的误差;
% 100 个epochs 之后可以获得1.2%的误差。
rand('state',0)
% 网络结构
cnn.layers = {
struct('type', 'i') % 输入层
struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5) % 卷积层
struct('type', 's', 'scale', 2) % 上采样
struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) % 卷积层
struct('type', 's', 'scale', 2) % 上采样
};
% 网络初始化
cnn = cnnsetup(cnn, train_x, train_y);
% 参数
opts.alpha = 1;
opts.batchsize = 50;
opts.numepochs = 1;
% 训练
cnn = cnntrain(cnn, train_x, train_y, opts);
% 验证误差
[er, bad] = cnntest(cnn, test_x, test_y);
% 打印均方误差
figure; plot(cnn.rL);
% 如果er>=0.12 则报错
assert(er<0.12, 'Too big error');
网友评论