美文网首页工具癖MATLAB学习笔记个人专题
matlab学习笔记——DeepLearnToolBox搭建MN

matlab学习笔记——DeepLearnToolBox搭建MN

作者: dalalaa | 来源:发表于2018-10-21 19:56 被阅读34次

    DeepLearnToolBox是matlab下的一个简单的深度学习工具包,接口简单易用,其代码是纯matlab编写。

    使用过程非常简单,总共分两步:

    1. 在github上下载代码
    2. 打开matlab,在matlab命令行窗口中输入:addpath(genpath('所在文件夹\DeepLearnToolbox'));

    然后就可以愉快地敲代码了,下面是一个用于识别MNIST手写数字的官方示例:

    function test_example_CNN
    load mnist_uint8; % 加载手写数字
    
    % 处理数据
    train_x = double(reshape(train_x',28,28,60000))/255;
    test_x = double(reshape(test_x',28,28,10000))/255;
    train_y = double(train_y');
    test_y = double(test_y');
    %% 建立一个卷积神经网络
    % 跑一次循环需要200秒,一个epoch可以获得11%的误差;
    % 100 个epochs 之后可以获得1.2%的误差。
    rand('state',0)
    % 网络结构
    cnn.layers = {
        struct('type', 'i')  % 输入层
        struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5)  % 卷积层
        struct('type', 's', 'scale', 2)  % 上采样
        struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) % 卷积层
        struct('type', 's', 'scale', 2) % 上采样
    };
    % 网络初始化
    cnn = cnnsetup(cnn, train_x, train_y);
    
    % 参数
    opts.alpha = 1;
    opts.batchsize = 50;
    opts.numepochs = 1;
    
    % 训练
    cnn = cnntrain(cnn, train_x, train_y, opts);
    
    % 验证误差
    [er, bad] = cnntest(cnn, test_x, test_y);
    
    % 打印均方误差
    figure; plot(cnn.rL);
    
    % 如果er>=0.12 则报错
    assert(er<0.12, 'Too big error');
    

    相关文章

      网友评论

        本文标题:matlab学习笔记——DeepLearnToolBox搭建MN

        本文链接:https://www.haomeiwen.com/subject/vbcizftx.html