Getting Started with Deep Learning: Training MNIST with MXNet

Author: LonnieQ | Published 2019-12-31 20:48

    Source code

    https://github.com/LoniQin/AwsomeNeuralNetworks/blob/master/trainer_v1.py

    1. Import dependencies

    import time

    from mxnet import autograd, nd
    from mxnet.gluon import data as gdata
    

    2. Define constants

    As shown below, num_inputs is the number of input features (28 × 28 = 784 pixels per image), num_outputs is the number of output classes, batch_size is the number of examples in each mini-batch, num_workers is the number of worker processes used to load data, num_epochs is the number of training epochs, and learning_rate is the learning rate.

    num_inputs = 784
    num_outputs = 10
    batch_size = 256
    num_workers = 4
    num_epochs = 100
    learning_rate = 0.1
    

    3. Load the data

    mnist_train and mnist_test are the Fashion-MNIST training and test sets (a drop-in replacement for MNIST with the same image format, as the code below shows). Gluon first checks whether a local copy of the dataset exists; if so, it loads it, otherwise it downloads the data and caches it locally. train_iter and test_iter can then be iterated over with a for loop during training.

    mnist_train = gdata.vision.FashionMNIST(train=True)
    
    mnist_test = gdata.vision.FashionMNIST(train=False)
    
    transformer = gdata.vision.transforms.ToTensor()
    
    train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size, shuffle=True, num_workers=num_workers)
    
    test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size, shuffle=False, num_workers=num_workers)
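
    To sanity-check the loaders, we can pull one batch and inspect it; this is a small sketch, assuming the loaders defined above. ToTensor converts each image from uint8 HWC format to float32 CHW format scaled to [0, 1]:

    for X, y in train_iter:
        print(X.shape, X.dtype)  # (256, 1, 28, 28), float32 pixels scaled to [0, 1]
        print(y.shape, y.dtype)  # (256,), int32 class labels 0-9
        break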
    

    4. Define the classification function

    In image classification, if there are n classes, the network needs n outputs. We want to turn these n outputs into per-class probabilities and pick the most likely class. The softmax function handles this well. Suppose a dataset has n classes and the outputs computed by a linear model (or some other method) are o_1, \ldots, o_n; softmax is computed as:
    softmax(o_i) = \frac{\exp(o_i)}{\sum_{j=1}^n \exp(o_j)}
    Clearly the number of outputs equals the number of label classes; softmax produces a probability for each class, and the probabilities sum to 1.

    The Python code:

    def softmax(X):
        x_exp = X.exp()                               # exponentiate every element
        partition = x_exp.sum(axis=1, keepdims=True)  # row-wise normalizing constant
        return x_exp / partition                      # broadcast divide: rows sum to 1
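
    A quick check on toy data (a sketch, not part of the original script): each row of the output is a probability distribution. Note that this naive version can overflow when the inputs are large, since exp grows quickly; a common, numerically safer variant subtracts the row-wise maximum before exponentiating, which leaves the result mathematically unchanged:

    X = nd.array([[1.0, 2.0, 3.0]])
    print(softmax(X))              # [[0.09 0.24 0.67]] -- one probability per class
    print(softmax(X).sum(axis=1))  # [1.] -- each row sums to 1

    def stable_softmax(X):
        shifted = X - X.max(axis=1, keepdims=True)  # guard against overflow in exp
        x_exp = shifted.exp()
        return x_exp / x_exp.sum(axis=1, keepdims=True)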
    

    5. Define the loss function

    The loss function measures how well the network is performing. For classification problems, cross entropy is the usual choice: it measures the difference between two probability distributions. Its formula is:
    H(y^{(i)}, \hat{y}^{(i)}) = -\sum_{j=1}^q y_j^{(i)} \log \hat{y}_j^{(i)}
    For a training set of n examples, the cross-entropy loss is defined as:
    l(\theta) = \frac{1}{n}\sum_{i=1}^n H(y^{(i)}, \hat{y}^{(i)})

    def cross_entropy(y_hat, y):
        # pick the predicted probability of each example's true class, then take -log
        return -nd.pick(y_hat, y).log()
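
    nd.pick selects, for each row of y_hat, the predicted probability of that row's true label, so the loss is the negative log-likelihood of the correct class. A small sketch:

    y_hat = nd.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])  # predicted class probabilities
    y = nd.array([0, 2])                                  # true class indices
    print(nd.pick(y_hat, y))        # [0.1 0.5] -- probability given to the true class
    print(cross_entropy(y_hat, y))  # [2.3  0.69] -- lower loss when the true class gets more mass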
    

    6. Define the parameter update function

    During training we first randomly initialize a weight matrix W and a bias vector b, push the data through the network to get predictions, and compare them against the labels using the loss function, which gives a loss value. Backpropagation then computes the gradients of W and b, and some update rule adjusts W and b so that the next pass produces predictions closer to the answers. Stochastic gradient descent (SGD) is such a rule: it updates the parameters in proportion to a learning rate we set, typically chosen between 0 and 1.
    The implementation:

    def sgd(params, lr, batch_size):
        for i in range(len(params)):
            # in-place update: param <- param - lr * grad / batch_size
            nd.elemwise_sub(params[i], lr * params[i].grad / batch_size, out=params[i])
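
    As a minimal sketch of one update step: for a toy parameter x with loss sum(x^2), the gradient is 2x, so a single call moves x toward zero.

    x = nd.array([1.0, 2.0])
    x.attach_grad()
    with autograd.record():
        l = (x * x).sum()  # toy loss: gradient is 2 * x
    l.backward()           # x.grad is now [2. 4.]
    sgd([x], lr=0.1, batch_size=1)
    print(x)               # [0.8 1.6] -- moved a step toward the minimum at zero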
    

    We need to choose a suitable learning rate: if it is too small, the parameters take a long time to approach the optimum; if it is too large, we can easily overshoot it.

    7. Build the neural network

    We typically start a classification problem with a linear model, whose expression is:
    y = XW + b
    The implementation:

    def net(X, w, b):
        # flatten each image to a 784-vector, apply the affine map, then softmax
        y = nd.dot(X.reshape((-1, w.shape[0])), w) + b
        return softmax(y)
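
    A shape check with throwaway parameters (w_demo and b_demo here are toy values just for this sketch; the real parameters are initialized in the next section): a batch of two 28 × 28 images maps to two rows of 10 class probabilities.

    w_demo = nd.random.normal(scale=0.01, shape=(num_inputs, num_outputs))
    b_demo = nd.zeros(num_outputs)
    X = nd.random.normal(shape=(2, 1, 28, 28))  # a fake batch of two images
    y = net(X, w_demo, b_demo)
    print(y.shape)        # (2, 10): one probability per class per image
    print(y.sum(axis=1))  # [1. 1.]: each row is a probability distribution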
    

    This is only a single-layer network. In real training we may build networks several, or even hundreds of, layers deep; building deeper networks is one way to improve accuracy.
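
    As a sketch of that direction (not part of the original script), a two-layer version of net inserts a hidden layer with a ReLU activation; num_hiddens = 256 is an assumed hyperparameter, with w1 shaped (784, 256) and w2 shaped (256, 10):

    num_hiddens = 256  # assumed hidden-layer width, just for this sketch

    def deeper_net(X, w1, b1, w2, b2):
        h = nd.relu(nd.dot(X.reshape((-1, w1.shape[0])), w1) + b1)  # hidden layer with ReLU
        return softmax(nd.dot(h, w2) + b2)                          # output layer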

    8. Train the model

    In this example we first initialize the weights w and bias b and attach gradient buffers to them. Each mini-batch contains batch_size = 256 examples; the Fashion-MNIST training set has 60,000 examples, so each epoch runs about 235 batches. The total number of epochs is num_epochs = 100. In each batch we call the network to compute the prediction \hat{y}, use the loss function to compute the loss l, and backpropagate l, which gives w and b their gradients; finally we call the parameter update function to update w and b.

    w = nd.random.normal(scale=1.0, shape=(num_inputs, num_outputs))
    b = nd.zeros(num_outputs)
    w.attach_grad()  # allocate gradient storage for w
    b.attach_grad()  # allocate gradient storage for b
    loss = cross_entropy
    # Train the model
    start = time.time()
    total = float(len(mnist_train))
    for epoch in range(1, num_epochs + 1):
        acc_sum = 0.0
        for X, y in train_iter:
            with autograd.record():  # record operations for autodiff
                y_hat = net(X, w, b)
                l = loss(y_hat, y).sum()
            l.backward()             # compute gradients of w and b
            sgd([w, b], learning_rate, batch_size)
            acc_sum += (y_hat.argmax(axis=1) == y.astype('float32')).sum().asscalar()
        print("Epoch:%d Elapsed time:%.2f accuracy:%.2f%%" % (epoch, time.time() - start, (acc_sum / total) * 100))
    

    9. Results

    Over 100 epochs of training, the training accuracy climbs from 44% to 83.85%.

    Epoch:1 Elapsed time:1.37 accuracy:44.04%
    Epoch:2 Elapsed time:2.69 accuracy:62.63%
    Epoch:3 Elapsed time:4.00 accuracy:67.57%
    Epoch:4 Elapsed time:5.32 accuracy:70.34%
    Epoch:5 Elapsed time:6.63 accuracy:72.17%
    Epoch:6 Elapsed time:8.09 accuracy:73.39%
    Epoch:7 Elapsed time:9.36 accuracy:74.33%
    Epoch:8 Elapsed time:10.63 accuracy:75.06%
    Epoch:9 Elapsed time:11.93 accuracy:75.83%
    Epoch:10 Elapsed time:13.20 accuracy:76.27%
    Epoch:11 Elapsed time:14.59 accuracy:76.70%
    Epoch:12 Elapsed time:15.89 accuracy:77.01%
    Epoch:13 Elapsed time:17.22 accuracy:77.41%
    Epoch:14 Elapsed time:18.59 accuracy:77.81%
    Epoch:15 Elapsed time:19.93 accuracy:78.02%
    Epoch:16 Elapsed time:21.31 accuracy:78.27%
    Epoch:17 Elapsed time:22.69 accuracy:78.50%
    Epoch:18 Elapsed time:24.12 accuracy:78.70%
    Epoch:19 Elapsed time:25.46 accuracy:78.91%
    Epoch:20 Elapsed time:26.82 accuracy:78.98%
    Epoch:21 Elapsed time:28.27 accuracy:79.24%
    Epoch:22 Elapsed time:29.70 accuracy:79.34%
    Epoch:23 Elapsed time:31.25 accuracy:79.60%
    Epoch:24 Elapsed time:32.62 accuracy:79.64%
    Epoch:25 Elapsed time:34.35 accuracy:79.73%
    Epoch:26 Elapsed time:35.84 accuracy:79.92%
    Epoch:27 Elapsed time:37.32 accuracy:80.06%
    Epoch:28 Elapsed time:38.97 accuracy:80.14%
    Epoch:29 Elapsed time:40.59 accuracy:80.23%
    Epoch:30 Elapsed time:42.28 accuracy:80.45%
    Epoch:31 Elapsed time:43.77 accuracy:80.52%
    Epoch:32 Elapsed time:45.43 accuracy:80.58%
    Epoch:33 Elapsed time:47.05 accuracy:80.69%
    Epoch:34 Elapsed time:48.70 accuracy:80.80%
    Epoch:35 Elapsed time:50.20 accuracy:80.86%
    Epoch:36 Elapsed time:51.62 accuracy:80.90%
    Epoch:37 Elapsed time:53.17 accuracy:81.02%
    Epoch:38 Elapsed time:54.65 accuracy:81.08%
    Epoch:39 Elapsed time:56.22 accuracy:81.21%
    Epoch:40 Elapsed time:57.79 accuracy:81.25%
    Epoch:41 Elapsed time:59.29 accuracy:81.34%
    Epoch:42 Elapsed time:60.87 accuracy:81.34%
    Epoch:43 Elapsed time:62.45 accuracy:81.55%
    Epoch:44 Elapsed time:64.04 accuracy:81.73%
    Epoch:45 Elapsed time:65.41 accuracy:81.56%
    Epoch:46 Elapsed time:66.79 accuracy:81.73%
    Epoch:47 Elapsed time:68.23 accuracy:81.68%
    Epoch:48 Elapsed time:69.78 accuracy:81.74%
    Epoch:49 Elapsed time:71.24 accuracy:81.91%
    Epoch:50 Elapsed time:72.76 accuracy:81.87%
    Epoch:51 Elapsed time:74.37 accuracy:81.91%
    Epoch:52 Elapsed time:76.27 accuracy:82.09%
    Epoch:53 Elapsed time:78.28 accuracy:82.04%
    Epoch:54 Elapsed time:80.26 accuracy:82.15%
    Epoch:55 Elapsed time:82.13 accuracy:82.20%
    Epoch:56 Elapsed time:83.70 accuracy:82.29%
    Epoch:57 Elapsed time:85.53 accuracy:82.27%
    Epoch:58 Elapsed time:87.22 accuracy:82.38%
    Epoch:59 Elapsed time:88.72 accuracy:82.36%
    Epoch:60 Elapsed time:90.45 accuracy:82.47%
    Epoch:61 Elapsed time:92.20 accuracy:82.39%
    Epoch:62 Elapsed time:93.97 accuracy:82.48%
    Epoch:63 Elapsed time:95.59 accuracy:82.50%
    Epoch:64 Elapsed time:97.21 accuracy:82.61%
    Epoch:65 Elapsed time:98.69 accuracy:82.65%
    Epoch:66 Elapsed time:100.37 accuracy:82.78%
    Epoch:67 Elapsed time:102.10 accuracy:82.74%
    Epoch:68 Elapsed time:103.85 accuracy:82.77%
    Epoch:69 Elapsed time:105.74 accuracy:82.79%
    Epoch:70 Elapsed time:107.22 accuracy:82.91%
    Epoch:71 Elapsed time:108.85 accuracy:82.94%
    Epoch:72 Elapsed time:110.55 accuracy:83.04%
    Epoch:73 Elapsed time:112.52 accuracy:83.05%
    Epoch:74 Elapsed time:114.06 accuracy:83.08%
    Epoch:75 Elapsed time:115.77 accuracy:83.09%
    Epoch:76 Elapsed time:117.37 accuracy:83.16%
    Epoch:77 Elapsed time:119.14 accuracy:83.26%
    Epoch:78 Elapsed time:120.74 accuracy:83.20%
    Epoch:79 Elapsed time:122.41 accuracy:83.23%
    Epoch:80 Elapsed time:123.95 accuracy:83.23%
    Epoch:81 Elapsed time:125.73 accuracy:83.24%
    Epoch:82 Elapsed time:127.32 accuracy:83.31%
    Epoch:83 Elapsed time:128.98 accuracy:83.36%
    Epoch:84 Elapsed time:130.51 accuracy:83.39%
    Epoch:85 Elapsed time:132.11 accuracy:83.47%
    Epoch:86 Elapsed time:133.66 accuracy:83.47%
    Epoch:87 Elapsed time:135.29 accuracy:83.48%
    Epoch:88 Elapsed time:136.92 accuracy:83.55%
    Epoch:89 Elapsed time:138.61 accuracy:83.50%
    Epoch:90 Elapsed time:140.20 accuracy:83.60%
    Epoch:91 Elapsed time:141.80 accuracy:83.58%
    Epoch:92 Elapsed time:143.51 accuracy:83.66%
    Epoch:93 Elapsed time:145.01 accuracy:83.63%
    Epoch:94 Elapsed time:146.40 accuracy:83.66%
    Epoch:95 Elapsed time:147.87 accuracy:83.71%
    Epoch:96 Elapsed time:149.47 accuracy:83.79%
    Epoch:97 Elapsed time:151.07 accuracy:83.79%
    Epoch:98 Elapsed time:152.62 accuracy:83.84%
    Epoch:99 Elapsed time:154.46 accuracy:83.79%
    Epoch:100 Elapsed time:156.07 accuracy:83.85%
    

    This result is encouraging, and with further optimization we can reach higher accuracy. Possible directions:

    • Build a deeper network.
    • Add more training samples.
    • Choose a suitable parameter update method. Besides stochastic gradient descent there are momentum, AdaGrad, RMSProp, AdaDelta, Adam, and others (see the sketch after this list).
    • Use weight decay and dropout to prevent overfitting.
    • Use a CNN, which is better suited to image recognition.
    • Improve how the weight and bias parameters are initialized.
    • Improve the loss function.
    • Save the final or best parameters from training for future reuse.
    • Optimize computational performance with asynchronous computation, automatic parallelism, multi-GPU training, and similar techniques.
    • Choose sensible hyperparameters: the number of layers, the activation function in each layer, the number of parameters per layer, the batch size, the number of epochs, and so on. We can set up servers to brute-force search for better-performing hyperparameters.
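
    For instance, the momentum method mentioned above keeps a velocity for each parameter. A minimal sketch of a drop-in replacement for sgd, following the common formulation (gamma is an assumed momentum coefficient, not from the original script):

    def sgd_momentum(params, states, lr, batch_size, gamma=0.9):
        for param, v in zip(params, states):
            v[:] = gamma * v + lr * param.grad / batch_size  # accumulate a velocity
            param[:] = param - v                             # apply the update

    # one zero-initialized velocity per parameter
    states = [nd.zeros_like(w), nd.zeros_like(b)]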
