美文网首页
numpy实现一个神经网络识别手写数字数据集

numpy实现一个神经网络识别手写数字数据集

作者: doinb1517 | 来源:发表于2021-12-08 18:59 被阅读0次

    前言

    机器学习在安全行业的应用非常广,这就要求我们在深耕自己细分领域的同时还应该广泛涉猎;对机器学习相关基础知识,数据分析基本概念有所掌握。提起机器学习,大家想到的可能就是各大知名框架,大家对框架的选择也都各有所爱,得益于框架的良好封装,自己快速搭建一个机器学习网络并不是什么难事。使用框架进行编程的第一句都免不了import某个框架,使用过程变成了无聊的“调参侠”(仅对低级使用者而言,专业的算法工程师还是有很强的业务能力),这就导致我们常常会忽视算法的本质,本次使用numpy实现一个ANN来帮助理解机器学习相关概念,一个框架所有的部分都可以使用numpy实现。真正完成一次实验会对机器学习的本质有更加深入的认识。

    背景

    本次使用最简单的MNIST进行实验,使用numpy实现一个可以识别手写数字集的人工神经网络。

    Mnist:大多数示例使用手写数字的MNIST数据集[1]。该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。为简单起见,每个图像都被平展并转换为784(28 * 28)个特征的一维numpy数组。

    首先使用简单的keras框架构建一个网络进行手写数字的识别,网上相关资料很多,不再赘述,这里给出一份代码。

    import numpy as np
    import keras
    import pandas as pd
    from keras import layers
    from matplotlib import pyplot as plt
    from keras.datasets import mnist as mn
    
    %matplotlib inline
    
    # 读取训练数据和测试数据
    (train_img, train_lab), (test_img, test_lab) = mn.load_data()
    model = keras.Sequential()
    model.add(layers.Flatten()) # (60000, 28, 28) => (60000, 28*28)
    model.add(layers.Dense(64, activation='tanh'))
    model.add(layers.Dense(10, activation='softmax'))
    
    # 编译模型
    model.compile(
        optimizer="adam",
        # 注意因为label是顺序编码的,这里用这个
        loss='sparse_categorical_crossentropy',
        metrics = ['accuracy']
    )
    
    # 模型结构
    model.summary()
    
    # 使用history保存每个epoch结束的loss,accuracy等信息
    history = model.fit(train_img, train_lab, epochs=10, batch_size=500, validation_data=(test_img, test_lab), verbose=2) # 每批500张图片
    
    # 保存模型
    model.save('keras_mnist.h5')
    

    可视化训练过程,使用history保存的信息画出折线图。

    plt.plot(history.history['val_accuracy'], c='g', label='validation acc')
    plt.plot(history.history['accuracy'], c='b', label='train acc')
    plt.legend()
    plt.show()
    
    
    train_acc.png
    plt.plot(history.history['val_loss'], c='g', label='validation loss')
    plt.plot(history.history['loss'], c='b', label='train loss')
    plt.legend()
    plt.show()
    
    train_loss.png

    使用模型进行预测

    # 加载训练的模型
    from keras.models import load_model
    model = load_model("model_name.h5")
    
    result = model.predict(test_img)
    def show_test(index):
        plt.imshow(test_img[index],cmap='gray')
        print("label : {}".format(test_lab[index]))
        print("predict : {}".format(result[index].argmax()))
        
    index = np.random.randint(1, len(test_img))
    show_test(index)
    
    keras_predict.png

    上面部分代码就是使用keras实现的过程,非常简单。接下来进入主题,可以对比一下手工实现和使用框架实现的区别。

    首先需要明确的输入输出的维度,输入维度很简单,像素是28*28的,我们把每行的数据拼接起来,一张图片的维度就是28*28=784维的向量,输出的是0-9的10维向量,

    输入: 784  
    输出: 10  
    

    由此我们构建一个最简单的只有一个隐藏层的神经网络。神经元之间采用最简单的线性连接。下面的公式就是一个最简单的神经网络,后续的全部工作就是实现这两个公式。

    data是输入的图片,维度是[1,784];output是输出的预测结果维度是[1,10];A和B都是激活函数;h是隐藏层;
    \vec{h} = A(data + b_0)

    \vec{output} = B(\vec{h}w_1 + b_1)

    根据输入输出确定其他参数的维度,只有参数维度不出问题才能确保下面的流程正确进行

    b0和data同形,b0维度也是[1,784],

    所以h也是[1,784]维

    b1和output同形,b1维度也是[1,10]

    根据矩阵乘法w1是[784,10]

    将上面分析的结果带入原公式中,确认分析没毛病。

    data:784
    output:10
    b_0:784
    h:784
    w_1:[784, 10]
    b_1:10
    
    [1,784] = [1,784] + [1,784]
    [1,10] = [1,784][784,10] + [1,10]
    

    确认参数维度之后我们开始初始化公式中的参数。

    导入必要的包

    import math
    import copy
    import numpy as np
    import matplotlib.pyplot as plt
    
    # 定义参数的维度
    dimensions=[28*28,10]
    activation=[tanh,softmax]
    distribution=[
        {'b':[0,0]},
        {'b':[0,0],'w':[-1,1]},
    ]
    #实现初始化参数
    def init_parameters_b(layer):
        dist=distribution[layer]['b']
        return np.random.rand(dimensions[layer])*(dist[1]-dist[0])+dist[0]
    def init_parameters_w(layer):
        dist=distribution[layer]['w']
        return np.random.rand(dimensions[layer-1],dimensions[layer])*(dist[1]-dist[0])+dist[0]
    def init_parameters():
        parameter=[]
        for i in range(len(distribution)):
            layer_parameter={}
            for j in distribution[i].keys():
                if j=='b':
                    layer_parameter['b']=init_parameters_b(i)
                    continue
                if j=='w':
                    layer_parameter['w']=init_parameters_w(i)
                    continue
            parameter.append(layer_parameter)
        return parameter
    parameters=init_parameters()
    

    这样parameters就是我们初始化成功的参数,验证参数生成是否正确

    #测试参数生成
    import tensorflow as tf
    print(tf.shape(parameters[0]['b']))
    print(tf.shape(parameters[1]['b']))
    print(tf.shape(parameters[1]['w']))
    parameters
    

    输出如下,可以看到输出和我们的预期一致。训练神经网络模型的过程实际上就是找到合适的参数的过程,等训练结束之后可以比较一下parameters的变化。

    tf.Tensor([784], shape=(1,), dtype=int32)
    tf.Tensor([10], shape=(1,), dtype=int32)
    tf.Tensor([784  10], shape=(2,), dtype=int32)
    
    [{'b': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0.])},
     {'b': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
      'w': array([[-0.11063313,  0.01986579, -0.54817987, ...,  0.06447563,
               0.55723463, -0.36988999],
             [ 0.93507145, -0.13798472, -0.68732584, ...,  0.04820882,
              -0.33476673, -0.69842804],
             [-0.22759392, -0.61509861, -0.93002526, ..., -0.45658224,
              -0.4769593 , -0.68456901],
             ...,
             [ 0.28200054, -0.92222148, -0.16388762, ...,  0.95929227,
               0.60109395,  0.84298182],
             [ 0.22296947, -0.4467861 , -0.65828542, ..., -0.10003993,
               0.29943921, -0.46707877],
             [ 0.01722276,  0.04571887, -0.87339843, ..., -0.03931738,
              -0.36247935,  0.61174093]])}]
    

    下一步我们实现公式中的两个激活函数A和B

    这里A函数使用tanh()做激活函数,这里也可以采用其他函数,因为本次需要手动计算梯度等操作,所以选择一个比较简单的函数方便计算。B函数softmax()用作分类

    #定义需要的两个激活函数
    def tanh(x):
        return np.tanh(x)
    def softmax(x):
        exp=np.exp(x-x.max())
        return exp/exp.sum()
    

    这里exp=np.exp(x-x.max())的操作是为了防止指数爆炸引起上溢,如果计算np.exp(1000)会报RuntimeWarning: overflow encountered in exp的错误,将每一位数都减去这组数中最大的那一位,是不会对结果产生影响的(高中知识)

    计算一组测试数据softmax(np.array([1,2,3])),和softmax(np.array([-2,-1,0]))运行结果都如下

    array([0.09003057, 0.24472847, 0.66524096])
    

    接下来定义预测函数,就是output = B(h w_1 + b_1)这部分,给定一个784的输入,给出一个10维的输出。下面的简单代码就构建完成了一个神经网络,接下来我们要做的就是利用梯度下降更新parameters,让这个神经网络表现出更好的分类效果。

    #定义预测函数
    def predict(img,parameters):
        l0_in=img+parameters[0]['b']
        l0_out=activation[0](l0_in)
        l1_in=np.dot(l0_out,parameters[1]['w'])+parameters[1]['b']
        l1_out=activation[1](l1_in)
        return l1_out
    

    运行预测函数输出如下:

    # 测试预测函数
    # predict(np.random.rand(784),parameters)
    
    array([1.31779086e-06, 2.43796927e-07, 2.34876198e-10, 4.46492815e-02,
           5.90359130e-02, 8.78048316e-02, 1.67109396e-02, 2.06152860e-10,
           7.91797467e-01, 5.26289356e-09])
    

    读取数据并展示数据

    # 读取数据集
    from pathlib import Path
    import struct
    dataset_path=Path('./MNIST')
    train_img_path=dataset_path/'train-images.idx3-ubyte'
    train_lab_path=dataset_path/'train-labels.idx1-ubyte'
    test_img_path=dataset_path/'t10k-images.idx3-ubyte'
    test_lab_path=dataset_path/'t10k-labels.idx1-ubyte'
    
    # 5w训练集,1w验证集,1w测试数据集
    train_num=50000
    valid_num=10000
    test_num=10000
    
    with open(train_img_path,'rb') as f:
        struct.unpack('>4i',f.read(16))
        tmp_img=np.fromfile(f,dtype=np.uint8).reshape(-1,28*28)/255
        train_img=tmp_img[:train_num]
        valid_img=tmp_img[train_num:]
        
    with open(test_img_path,'rb') as f:
        struct.unpack('>4i',f.read(16))
        test_img=np.fromfile(f,dtype=np.uint8).reshape(-1,28*28)/255
    
    with open(train_lab_path,'rb') as f:
        struct.unpack('>2i',f.read(8))
        tmp_lab=np.fromfile(f,dtype=np.uint8)
        train_lab=tmp_lab[:train_num]
        valid_lab=tmp_lab[train_num:]
        
    with open(test_lab_path,'rb') as f:
        struct.unpack('>2i',f.read(8))
        test_lab=np.fromfile(f,dtype=np.uint8)
    
    def show_train(index):
        plt.imshow(train_img[index].reshape(28,28),cmap='gray')
        print('label : {}'.format(train_lab[index]))
    def show_valid(index):
        plt.imshow(valid_img[index].reshape(28,28),cmap='gray')
        print('label : {}'.format(valid_lab[index]))
    def show_test(index):
        plt.imshow(test_img[index].reshape(28,28),cmap='gray')
        print('label : {}'.format(test_lab[index]))
    

    读取数据和测试数据的部分就完成了,测试数据集中的数据是不会在训练集中出现的,这就好像把高考原题给你去训练,有可能你是仅仅背会了这个题目而不是真正的理解了如何去分析解题,所以高考基本不会出现原题,这样才能考察出你对知识的理解掌握程度。

    # 测试show_train
    show_train(np.random.randint(train_num))
    
    test.png

    接下来就是更新参数的部分,参数我们都是随机生成的,自然神经网络给出的预测结果也是随机的,我们如何去评价预测的结果和真实结果的差异呢,这种评价机制就是损失函数。理想状态下神经网络给出的预测应该是这样的:比如这个图片真实的标签是1,神经网络预测这个图片是1的概率为100%,为其他数字的概率都是0。然而显示中神经网络达不到这样的效果,更多时候都是“人工智障”,我们将所有预测中概率最大的那个作为神经网络的预测结果。

    # 定义损失函数
    onehot=np.identity(10)
    def sqr_loss(img, lab, parameters):
        y_pred=predict(img,parameters)
        y=onehot[lab]
        diff=y-y_pred
        return np.dot(diff,diff)
    

    定义完损失函数之后,我们只需要让损失函数持续的减少,就能向理想化的预测结果不断靠近,究竟如何进行参数的更新,这里就要使用梯度,大学高数告诉我们梯度方向是上升最快的方向,负梯度方向为下降最快的方向,通过往梯度(gradient)下降的方向调整参数,逐步减小损失函数loss function的值,从而得到训练好的模型。

    #定义两个激活函数的导数
    def d_softmax(data):
        sm=softmax(data)
        return np.diag(sm)-np.outer(sm,sm)
    
    def d_tanh(data):
        return 1/(np.cosh(data))**2
    
    differential={softmax:d_softmax,tanh:d_tanh}
    
    def grad_parameters(img,lab,parameters):
        l0_in=img+parameters[0]['b']
        l0_out=activation[0](l0_in)
        l1_in=np.dot(l0_out,parameters[1]['w'])+parameters[1]['b']
        l1_out=activation[1](l1_in)
        
        diff=onehot[lab]-l1_out
        act1=np.dot(differential[activation[1]](l1_in),diff)
        
        grad_b1=-2*act1
        grad_w1=-2*np.outer(l0_out,act1)
        grad_b0=-2*differential[activation[0]](l0_in)*np.dot(parameters[1]['w'],act1)
        
        return {'w1':grad_w1,'b1':grad_b1,'b0':grad_b0}
    

    输入图片和标签和初始化的参数后,就会计算梯度

    # 测试梯度计算
    # grad_parameters(train_img[2],train_lab[2],init_parameters())
    
    {'w1': array([[-0., -0., -0., ..., -0., -0., -0.],
            [-0., -0., -0., ..., -0., -0., -0.],
            [-0., -0., -0., ..., -0., -0., -0.],
            ...,
            [-0., -0., -0., ..., -0., -0., -0.],
            [-0., -0., -0., ..., -0., -0., -0.],
            [-0., -0., -0., ..., -0., -0., -0.]]),
     'b1': array([-8.97529566e-04, -1.57305145e-04, -1.13258721e-05,  1.51856573e-02,
            -3.34573808e-03, -8.80455733e-04, -6.32826250e-03, -3.27152101e-03,
            -2.34009828e-04, -5.95096104e-05]),
     'b0': array([-8.35839128e-03, -5.34125889e-03, -1.37208082e-02,  5.61553159e-03,
             1.95620702e-02,  4.82436064e-03, -1.81148739e-02,  7.06783791e-03,
             6.77894581e-03, -6.20645121e-05, -1.03271620e-02,  8.64674933e-03,
             1.20581038e-02, -1.66106477e-02,  1.42724550e-03, -2.10642636e-03,
             5.17487744e-03,  6.77834056e-03,  5.23595222e-03,  9.60313306e-03,
             5.52857690e-03,  1.52551626e-02,  1.26690528e-02, -9.09700995e-03,
            -7.05989943e-03, -1.50500067e-03,  1.54893607e-02,  9.42818548e-03,
            -6.14107702e-03, -5.62287700e-03, -1.77016958e-02, -1.08695394e-02,
            -4.73702194e-03,  9.69312031e-04,  3.38897299e-05,  9.68205152e-03,
             1.40352488e-02,  1.28429731e-02,  8.39480940e-04,  1.05712592e-02,
            -1.73710289e-02, -9.64789484e-03,  7.18972793e-03,  1.03710320e-02,
            -3.68373317e-03,  1.56242949e-03,  4.60410099e-03, -1.41165218e-02,
            -1.05017685e-02, -1.69197853e-02,  1.14012162e-02, -3.17723825e-03,
             5.86251639e-03, -4.55389491e-03, -6.27766397e-04, -2.45242403e-03,
             3.32728074e-03,  9.66014811e-03,  8.54219077e-03, -2.63156769e-03,
             7.90602741e-03,  5.91128736e-04, -4.63103323e-03,  7.27767514e-03,
             2.20731150e-03,  1.65127510e-02,  5.15687997e-03, -4.42550832e-04,
             1.57667838e-03,  4.46071508e-03,  5.84473192e-03,  9.07912645e-03,
             2.09107272e-02,  2.36699460e-02,  1.28038267e-02, -1.19544611e-02,
             1.16261562e-02, -1.62204299e-02,  3.32573066e-04,  1.43708779e-02,
            -1.15140572e-02, -1.62038995e-02, -6.64056214e-03, -2.68635750e-03,
            -4.15420403e-03,  1.43767043e-02, -3.52162151e-03,  7.34874136e-03,
            -1.50134966e-02, -1.85408493e-03, -5.97253728e-03, -7.70481837e-03,
             1.85042403e-03,  1.69247570e-03,  4.48173562e-03,  3.94761096e-03,
            -6.09264268e-03,  1.84609092e-02,  9.15935735e-03, -1.60288369e-02,
             2.47151419e-03, -1.75132583e-03,  2.15554612e-02, -1.08568207e-02,
            -5.38923467e-03,  1.63112882e-02,  1.36774969e-03, -1.01017574e-03,
            -9.67172656e-03, -2.05714170e-02,  1.31786779e-02, -1.20454389e-02,
             9.62068604e-03,  1.20974865e-02,  1.74756972e-02,  3.63575259e-03,
            -1.33278364e-02, -8.15359725e-03,  3.03751796e-03, -8.62243956e-03,
            -3.93034010e-03, -7.80119830e-03,  4.28420827e-03,  3.99100767e-03,
            -6.96652322e-03,  9.98542524e-03,  1.32837396e-02, -1.64502311e-02,
            -2.45694320e-03,  2.86176292e-03, -4.70804440e-03,  7.45219664e-03,
            -7.93939346e-03,  9.73363380e-04, -4.65506787e-03,  5.09916398e-03,
            -3.79960523e-03,  2.49196922e-02, -3.19170301e-03, -8.78287071e-04,
             8.60141093e-03, -5.18163659e-03, -7.67823469e-03,  9.12309830e-03,
             5.69992325e-03, -9.61483330e-03,  5.34989829e-04,  1.34362523e-03,
             1.48173676e-02,  1.01570384e-02, -4.22456962e-03, -6.73901076e-03,
             1.12639146e-02, -4.97857054e-04,  7.50129527e-03, -1.62759133e-04,
             2.66076235e-03, -8.20197284e-03,  9.08791485e-03, -1.17750589e-02,
             4.10305929e-03,  4.44696391e-03,  7.20534143e-03,  1.27302774e-02,
             2.09109442e-02, -1.17840327e-02,  1.68535551e-02, -5.45905767e-03,
             1.51291239e-02, -8.07673020e-03,  9.48311248e-03, -9.34733696e-05,
             1.01434332e-02, -3.11975379e-03,  5.63583825e-03,  5.75590337e-03,
             1.14765621e-02, -1.65725199e-02, -1.81781627e-02, -4.29306537e-03,
             8.29402330e-03,  6.48925555e-03,  9.36958945e-03, -4.97385263e-03,
             3.27683831e-04,  7.87043527e-03, -1.19547598e-02, -1.34038752e-02,
             6.52915908e-03,  1.44435525e-03,  8.83738335e-03,  7.74306472e-03,
            -1.00514819e-02,  5.32062237e-03,  8.15320109e-03,  1.27634258e-02,
            -1.06853142e-02, -2.37312411e-03, -1.46285969e-02,  1.33049237e-02,
             1.30542203e-02, -2.52723104e-03,  8.27540521e-03, -1.02999232e-03,
            -1.44753332e-02,  5.52333176e-03,  6.69350221e-03, -1.58433410e-02,
             7.49662163e-03, -7.04779382e-03, -1.22842597e-02, -1.84830282e-03,
             1.39855218e-02,  6.31352913e-03, -1.48297463e-02,  3.38934718e-03,
            -6.62877874e-03,  5.75719349e-03,  1.44137908e-02,  1.95541441e-02,
             8.10794967e-03, -4.58825637e-03, -6.31258088e-03,  8.76559737e-03,
            -1.49603323e-02, -9.71335649e-03, -3.96255520e-03, -2.12043313e-03,
            -3.12815557e-03,  5.99221757e-03,  5.43202946e-03, -9.03813632e-04,
            -9.52207459e-03, -1.38406399e-02,  1.63812992e-02,  1.43130646e-02,
            -6.26658585e-03, -1.84001540e-02,  8.79091617e-03,  9.12318712e-03,
            -3.84151764e-03,  7.49930864e-03, -2.33399185e-03, -1.34922235e-02,
            -1.20013433e-03, -2.45301294e-03,  1.87235962e-02,  6.47738169e-03,
             9.04561565e-03, -1.78941246e-03, -2.89415340e-03,  2.10333201e-03,
            -6.02399020e-04, -4.35288849e-03,  1.29300391e-02, -3.85006484e-03,
             7.31630507e-04, -6.38921249e-03, -5.93935905e-05, -1.51483706e-02,
            -1.28440162e-03, -1.43829370e-02, -3.33054898e-03, -2.97546136e-03,
            -6.58845675e-03, -6.52283492e-03, -7.70284688e-03, -1.36712024e-02,
            -8.54573417e-03, -9.90562153e-03, -6.22076017e-03, -9.82411806e-03,
             1.45459409e-03,  7.68754545e-03,  1.32780017e-02,  9.00071756e-04,
            -4.74501190e-03,  4.00242229e-03, -3.53622701e-03,  7.31449850e-03,
            -9.75946834e-03,  7.65593601e-03,  6.07010229e-03, -1.72909057e-03,
            -1.69049170e-03, -2.10693899e-03, -7.90723651e-03, -7.45393514e-03,
             4.24886363e-03,  7.32627994e-03,  5.10599485e-03, -1.76507662e-02,
            -5.49388863e-03,  1.46263929e-02,  1.18535559e-02,  1.26915167e-03,
             6.46282523e-04, -1.42360909e-02,  7.37329855e-03, -6.26288989e-05,
             2.43055218e-03,  8.92408413e-03,  1.24414497e-02,  6.55988247e-03,
             9.37842213e-03,  7.08007913e-04, -7.35278334e-03,  5.28320647e-03,
             8.88771246e-03,  1.39800623e-02,  1.26165278e-02,  4.43142767e-03,
             3.47184004e-03, -1.04267387e-02,  1.03648557e-02,  5.17463675e-03,
             1.11589419e-02,  1.46939647e-02,  1.61697670e-02, -1.12996099e-02,
             6.94902080e-03,  1.48351031e-02, -2.95375616e-04, -1.86238982e-03,
            -1.02235549e-02,  5.61529211e-03,  5.47667190e-03, -1.27005573e-03,
             5.36429217e-03, -1.51237814e-02, -1.86295636e-02,  7.33730686e-03,
            -1.85398513e-02,  1.20012698e-02,  1.38622315e-02,  1.48138021e-02,
             4.54187744e-03,  8.56565364e-03, -4.50984784e-03, -4.16407036e-03,
            -3.78523864e-04,  5.17276764e-03, -1.07159807e-03, -5.34147242e-03,
            -4.64290015e-04, -2.55771720e-02,  2.87195933e-03, -5.38110333e-03,
             1.17607976e-02, -5.48219436e-03, -8.98229736e-03,  1.72540309e-02,
             3.34464763e-03, -1.02802491e-02,  6.04949772e-03, -4.29102799e-04,
             3.03261484e-03,  2.72251956e-03,  1.06288529e-02, -9.31323588e-04,
             3.63547515e-03,  1.11205136e-02,  8.68301853e-03, -1.52172758e-02,
             7.28648699e-03, -3.12303012e-03, -1.79116254e-02,  6.46336746e-04,
            -6.21691734e-03,  1.98643930e-02, -1.00060605e-02, -1.62491068e-03,
             9.47750036e-03,  2.76323114e-03, -6.58888517e-03, -2.66808567e-03,
             8.18821416e-03, -3.26934196e-03,  6.36322084e-03, -9.80971793e-03,
             1.43097690e-02, -3.24592005e-03, -2.93351502e-03,  4.66444163e-03,
            -1.04971944e-02,  5.32184795e-04,  1.22446871e-02,  3.68558230e-03,
             2.03293447e-03, -1.14694715e-02,  1.11611296e-02,  6.78077506e-03,
             9.15919965e-03, -6.48321204e-03,  9.14721124e-03,  9.81069558e-03,
             5.46416759e-04,  2.83696708e-03, -6.18660808e-03,  5.23040873e-03,
            -9.43232269e-03,  5.01350486e-04, -7.80010012e-03, -6.11710888e-03,
             5.13209665e-03, -4.63020314e-03, -1.83953692e-03, -4.59957165e-03,
            -1.81975961e-03,  1.13762265e-02,  5.89661920e-03,  3.79557880e-03,
            -7.71640983e-03, -1.23142688e-02,  5.73868261e-03,  1.34087970e-02,
             4.29147159e-03,  9.57904998e-03,  1.07044304e-02,  1.99221102e-02,
            -2.61706503e-03,  1.45577671e-02,  2.77056800e-03,  4.78140880e-03,
             1.20878142e-03,  3.42858245e-03, -1.69856130e-03,  7.83897376e-04,
            -6.53504210e-04,  4.18284416e-03, -9.40530680e-05, -1.60026261e-04,
            -3.15430401e-03, -4.17154650e-03,  1.54810025e-03, -7.50441948e-03,
             9.89481461e-05, -1.41143581e-02,  2.40869300e-03, -1.51957209e-03,
             1.16327294e-02, -9.55667123e-03,  1.44912348e-02, -8.35596990e-03,
             7.95354718e-03,  5.12660372e-03,  1.46623119e-02,  4.65114541e-03,
             1.74226145e-03, -5.67281482e-03, -5.50606280e-03,  4.15548027e-03,
             4.67841941e-03,  5.75094804e-03,  2.76135746e-03, -5.62818508e-03,
            -7.30893481e-04, -6.00692150e-03,  2.36115112e-03,  1.62798089e-02,
             4.02249298e-04,  1.47418653e-02,  7.20601078e-03,  1.14021295e-02,
             1.44496676e-02, -1.32249960e-02,  2.15247121e-03,  4.09362355e-03,
             1.45143285e-02,  1.31540261e-02,  1.55092259e-02, -4.18210240e-03,
             2.64852642e-03, -4.82464320e-03, -4.67698373e-03,  3.50518964e-03,
            -1.36248164e-02, -7.11345454e-03,  4.60594167e-03,  2.63389543e-03,
             1.34032875e-03, -1.29473325e-03, -3.35849685e-03,  1.67917512e-03,
             1.10303519e-02, -3.11467239e-05,  6.49283788e-03, -7.81435880e-03,
             1.76452284e-02,  1.04238370e-02, -4.69215195e-03,  5.81979544e-04,
            -1.98356723e-02, -5.70613891e-03,  8.49987248e-04,  1.68373729e-02,
             1.64191887e-02,  1.04608380e-02, -7.32840721e-03,  2.85896382e-03,
            -2.01301614e-04, -1.02907026e-02, -1.04965167e-02, -5.52796873e-04,
            -1.20299484e-04, -1.09925685e-03,  5.85160937e-03,  2.06569223e-02,
            -1.16742886e-02, -8.26894289e-03,  1.44519510e-02,  8.68505776e-03,
             1.19268832e-02,  5.26283060e-03,  2.20560264e-02,  1.00600326e-02,
            -1.88134070e-03, -1.69761724e-02, -6.26756005e-03, -3.38911679e-04,
             5.13680542e-03, -1.95889890e-03, -2.17098912e-03, -3.15865206e-03,
             1.47845344e-04,  7.96503637e-03, -6.92480846e-03,  3.70510666e-03,
             7.66695055e-03,  6.70358506e-03, -6.96085158e-03,  1.53629095e-02,
             1.86723844e-02, -7.97768542e-03,  1.13408760e-03,  1.64453412e-02,
             7.86555740e-03,  8.11357602e-03, -3.97513230e-03, -5.54770721e-03,
             1.86414221e-02,  1.63247283e-02, -7.57573261e-04,  3.04783372e-03,
             6.02811064e-03,  2.72770615e-04,  3.29787225e-03,  1.78250017e-02,
             8.11287070e-03,  1.23290614e-02,  3.27188272e-03, -8.57967903e-03,
             1.60769195e-02,  7.16771498e-03,  1.16647724e-02, -6.28664396e-03,
            -1.42586547e-02,  5.47605028e-03,  1.14969641e-03,  5.84922870e-04,
            -3.94756512e-03,  1.31725191e-02,  2.90039058e-03, -1.25577523e-02,
            -9.78917120e-03,  8.50944167e-03, -1.33852891e-02, -1.10724119e-03,
            -6.13591574e-03, -1.12463163e-02, -1.19716727e-02,  6.45789325e-03,
            -1.26230884e-02, -9.51967853e-03, -2.75361966e-03, -9.92746131e-03,
             1.45935455e-02,  4.65504099e-03,  3.31550071e-03, -1.04595145e-02,
             5.93363274e-03, -3.81389867e-04,  9.52299641e-03,  1.44994018e-02,
            -4.47748660e-03,  2.85882044e-03,  5.52815415e-03,  1.88440260e-02,
             1.69711364e-02, -7.93964903e-03, -9.42645940e-03,  1.36986903e-03,
            -2.12047186e-02, -9.85594315e-03, -1.25158026e-02,  9.31486036e-03,
            -1.52051804e-02,  1.69511449e-02, -1.91137863e-03, -1.35500385e-02,
             2.55165368e-03, -2.51141096e-02,  9.62349215e-03, -1.12377955e-02,
             9.98116223e-03, -1.07760556e-02, -1.69691111e-03,  5.61699226e-03,
            -1.22563609e-02,  5.54497848e-03, -4.46853532e-03, -1.39720663e-02,
             9.35346776e-03,  1.40185253e-02,  4.24760780e-03,  1.36261429e-02,
             1.79470501e-03,  1.41545587e-02,  6.32130317e-03, -1.15250463e-02,
            -8.54878402e-03, -1.64811988e-02,  1.90066621e-02,  1.54530384e-02,
            -5.55854794e-03,  1.26907388e-02,  5.39264936e-03, -1.06719618e-02,
            -1.22603262e-03, -5.33742688e-03, -1.26367239e-02,  1.92253332e-02,
            -2.95910850e-04,  8.77304966e-03, -5.68655143e-04,  1.38931266e-02,
             1.11144006e-02,  3.54186002e-04, -1.15603241e-03, -1.39468960e-02,
             1.29550223e-02,  1.13014075e-02,  1.31172510e-02,  2.47826652e-03,
            -9.42405051e-03, -2.58612186e-03, -5.29341447e-03, -1.04336783e-02,
             3.66736635e-03, -1.49207100e-02,  6.48964966e-03,  1.18624084e-03,
            -1.19865255e-02, -2.30032940e-02, -1.21192025e-02,  5.45466316e-04,
             1.08917380e-02, -5.32271281e-03, -1.62339706e-02,  5.28767036e-03,
             1.84310691e-03,  3.62564351e-03, -2.19418169e-04, -3.90339853e-03,
            -1.18366759e-02, -2.01445414e-02,  3.76155440e-03, -2.00397348e-03,
            -1.93908225e-04,  8.26558290e-03,  2.37206952e-03, -1.41144863e-02,
             2.40565301e-03, -8.41398971e-03,  4.25295067e-03,  6.57405357e-04,
            -1.38303645e-02,  8.42248375e-03, -3.40463762e-03, -5.49103642e-03,
            -8.01241591e-03, -2.25115870e-03, -6.17287000e-03, -1.46935803e-04,
            -1.63164622e-02,  1.60228803e-02,  1.76595427e-02, -1.12082168e-02,
             9.94637585e-04, -4.64384207e-03,  7.33769878e-03,  4.14695297e-03,
            -6.10681023e-03,  1.80445759e-02, -1.96333718e-03, -1.05458184e-02,
            -3.70002702e-03,  1.28312342e-02,  6.20129990e-03, -8.84648393e-03,
             5.92371651e-03,  2.55170845e-03,  1.52498118e-02,  1.04301101e-02,
            -1.03409919e-02, -1.58708793e-03, -3.73517551e-03, -1.30893179e-03,
            -7.22263076e-03,  1.13466126e-02, -7.20440767e-03, -1.12978988e-02,
             6.42180544e-03,  1.81445550e-02, -1.46089354e-02,  1.57596778e-02,
             1.01798202e-02,  6.24328365e-03,  7.05300350e-03, -8.92576614e-03,
            -5.20019107e-03, -2.01859488e-02, -5.75730127e-03, -7.83280726e-03,
             5.41992041e-03,  3.54268513e-03, -5.70186345e-03,  7.87650572e-03,
            -1.77062060e-02,  1.29850863e-03, -3.31465590e-03,  1.91965326e-02,
             1.79732193e-02, -6.32151857e-03,  7.32256338e-03, -2.07217350e-03,
            -1.50998518e-02, -1.82906671e-03, -5.47495891e-03, -2.03257958e-03,
            -4.06861838e-03, -7.15148229e-03, -3.64476684e-03,  6.81016866e-03,
            -1.38613009e-03,  5.10535689e-03,  1.06649154e-02, -4.20675519e-03,
             5.27032997e-03, -1.32378630e-02,  1.11323960e-02,  8.35729477e-03,
            -1.06519393e-02, -8.21342240e-03, -1.10285836e-02,  9.85257533e-03,
            -4.20384563e-03,  1.13903373e-02,  9.16708088e-03, -1.17642999e-03,
             6.08050951e-03, -9.38422557e-03,  6.79306611e-03, -2.19940829e-03,
            -1.08774752e-02, -9.61702798e-03,  8.93264036e-03, -5.04352703e-03,
            -1.01352665e-02, -1.50189916e-02,  4.42095519e-04, -1.59010711e-02,
            -1.20179188e-02,  1.75890783e-03, -1.23881319e-02,  1.18678946e-02,
             1.59798761e-03, -7.76694562e-03, -4.38607793e-03, -1.30714383e-02,
            -1.08108158e-02,  8.90569743e-03,  3.69362361e-04,  9.15744120e-03])}
    

    定义完梯度以后就可以开始实现梯度下降算法

    batch_size=100
    def train_batch(current_batch,parameters):
        grad_accu=grad_parameters(train_img[current_batch*batch_size+0],train_lab[current_batch*batch_size+0],parameters)
        for img_i in range(1,batch_size):
            grad_tmp=grad_parameters(train_img[current_batch*batch_size+img_i],train_lab[current_batch*batch_size+img_i],parameters)
            for key in grad_accu.keys():
                grad_accu[key]+=grad_tmp[key]
        for key in grad_accu.keys():
            grad_accu[key]/=batch_size
        return grad_accu
    
    def combine_parameters(parameters,grad,learn_rate):
        parameter_tmp=copy.deepcopy(parameters)
        parameter_tmp[0]['b']-=learn_rate*grad['b0']
        parameter_tmp[1]['b']-=learn_rate*grad['b1']
        parameter_tmp[1]['w']-=learn_rate*grad['w1']
        return parameter_tmp
    

    下面定义一些评估指标,方便后期可视化训练过程,就是把每一个epoch结束后的loss和accuracy等信息保存下来,后期可以通过分析这些信息来调整超参数,优化模型。

    def valid_loss(parameters):
        loss_accu=0
        for img_i in range(valid_num):
            loss_accu+=sqr_loss(valid_img[img_i],valid_lab[img_i],parameters)
        return loss_accu/(valid_num/10000)
    def valid_accuracy(parameters):
        correct=[predict(valid_img[img_i],parameters).argmax()==valid_lab[img_i] for img_i in range(valid_num)]
        return correct.count(True)/len(correct)
    def train_loss(parameters):
        loss_accu=0
        for img_i in range(train_num):
            loss_accu+=sqr_loss(train_img[img_i],train_lab[img_i],parameters)
        return loss_accu/(train_num/10000)
    def train_accuracy(parameters):
        correct=[predict(train_img[img_i],parameters).argmax()==train_lab[img_i] for img_i in range(train_num)]
        return correct.count(True)/len(correct)
        
    parameters=init_parameters()
    current_epoch=0
    train_loss_list=[]
    valid_loss_list=[]
    train_accu_list=[]
    valid_accu_list=[]
    

    准备工作全部完成,现在可以开始训练模型

    learn_rate=10**-0.6
    # learn_rate=1
    epoch_num=15
    for epoch_ in range(epoch_num):
        for i in range(train_num//batch_size):
            grad_tmp=train_batch(i,parameters)
            parameters=combine_parameters(parameters,grad_tmp,learn_rate)
        current_epoch+=1
        train_loss_list.append(train_loss(parameters))
        train_accu_list.append(train_accuracy(parameters))
        valid_loss_list.append(valid_loss(parameters))
        valid_accu_list.append(valid_accuracy(parameters))
    

    训练15轮,正确率达到90以上,接下来使用训练好的模型识别图片,随机实验多次以后发现效果还是相当不错的。

    test_index = np.random.randint(1000)
    show_test(test_index)
    predict_result = predict(test_img[test_index], parameters)
    print("predict:{}".format(predict_result.argmax()))
    
    predict.png

    接下来看看保存的训练过程

    # 可视化acc
    plt.plot(valid_accu_list, c='g', label='validation acc')
    plt.plot(train_accu_list, c='b', label='train acc')
    plt.legend()
    plt.savefig('train_acc.png')
    
    train_acc_1.png
    # 可视化loss
    plt.plot(valid_loss_list, c='g', label='validation_loss')
    plt.plot(train_loss_list, c='b', label='train_loss')
    plt.legend()
    plt.savefig('train_loss.png')
    
    train_loss_1.png

    我们还可以保存模型下一次使用,不过这里保存的模型仅仅是参数,而不是保存模型结构和参数,可以使用Python中的pickle

    # 保存参数
    import pickle
    model_prameters_name = 'Mnist_model.pkl'
    f = open(model_prameters_name, 'wb')
    pickle.dump(parameters, f)
    f.close()
    
    f = open(model_prameters_name, 'rb')
    param = pickle.load(f)
    print(param)
    f.close
    

    至此就在不使用框架的情况下完成了一个最简单的人工神经网络,相比使用keras实现的版本该有的部分基本都有了。框架对相关函数初始化的方式都做了一定的优化,可以加快训练速度,这都是需要很高的数学能力才可以完成。这里我们也参考相关参数优化的论文对这个模型进行一下优化。这里可以参考此论文Understanding the difficulty of training deep feedforward neural networks优化参数初始化部分,使用此种方式初始化参数后,模型可以更快收敛。

    glorot10a.png
    dimensions=[28*28,10]
    activation=[tanh,softmax]
    distribution=[
        {'b':[0,0]},
        {'b':[0,0],'w':[-math.sqrt(6/(dimensions[0]+dimensions[1])),math.sqrt(6/(dimensions[0]+dimensions[1]))]},
    ]
    

    相关的notebook已经上传至Github

    相关文章

      网友评论

          本文标题:numpy实现一个神经网络识别手写数字数据集

          本文链接:https://www.haomeiwen.com/subject/eiwtfrtx.html