美文网首页
mnist手写数字模型训练和识别

mnist手写数字模型训练和识别

作者：Echoooo_o | 发表于 2019-05-17 18:08 | 阅读 0 次
    # coding: utf-8
    import sys, os
    sys.path.append(os.pardir)
    
    import numpy as np
    from dataset.mnist import load_mnist
    from two_layer_net import TwoLayerNet
    from PIL import Image
    
    # Load MNIST with normalize=True and one-hot encoded labels.
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

    network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)

    # Hyperparameters.
    iters_num = 10000
    train_size = x_train.shape[0]
    batch_size = 100
    learning_rate = 0.1

    train_loss_list = []
    train_acc_list = []
    test_acc_list = []

    # Number of mini-batch iterations per epoch. Integer division keeps the
    # modulo test below exact: with true division (`/`) this is a float, and
    # `i % iter_per_epoch == 0` would almost never hold when train_size is not
    # a multiple of batch_size, so accuracy would stop being recorded.
    iter_per_epoch = max(train_size // batch_size, 1)

    for i in range(iters_num):
        # Sample a random mini-batch (np.random.choice draws with replacement
        # by default).
        batch_mask = np.random.choice(train_size, batch_size)
        x_batch = x_train[batch_mask]
        t_batch = t_train[batch_mask]

        # Gradient of the loss w.r.t. every parameter. The network also exposes
        # numerical_gradient(x_batch, t_batch) as an alternative method.
        grad = network.gradient(x_batch, t_batch)

        # Plain SGD step on each parameter array, in place.
        for key in ('W1', 'b1', 'W2', 'b2'):
            network.params[key] -= learning_rate * grad[key]

        loss = network.loss(x_batch, t_batch)
        train_loss_list.append(loss)

        # Once per epoch, measure accuracy on the full train and test sets.
        if i % iter_per_epoch == 0:
            train_acc = network.accuracy(x_train, t_train)
            test_acc = network.accuracy(x_test, t_test)
            train_acc_list.append(train_acc)
            test_acc_list.append(test_acc)
            print(train_acc, test_acc)
    
    def sigmoid(x):
        """Return the logistic function 1 / (1 + exp(-x)), element-wise for arrays."""
        denominator = 1 + np.exp(-x)
        return 1 / denominator
    
    def softmax(a):
        """Return the softmax of ``a``, normalized over all of its elements.

        The maximum element is subtracted before exponentiating. The shift
        cancels in the ratio, so the result is mathematically unchanged, but it
        keeps ``np.exp`` from overflowing to ``inf`` for large inputs such as
        ``[1000, 1000]``. Without it the ratio would become ``nan``.

        Note: the sum runs over the whole array, so a 2-D batch is normalized
        across the entire batch rather than per row.

        Args:
            a: array-like of scores (logits).

        Returns:
            An array of the same shape whose elements sum to 1. An empty input
            yields an empty array.
        """
        a = np.asarray(a)
        if a.size == 0:
            # np.max raises on empty input; there is nothing to normalize.
            return np.exp(a)
        exp_a = np.exp(a - np.max(a))
        return exp_a / np.sum(exp_a)
    
    def predict(network, x):
        """Classify ``x`` with the network's weights, print the label and return it.

        Runs the forward pass manually from ``network.params``:
        affine -> sigmoid -> affine -> softmax, then takes the arg-max.

        NOTE(review): this assumes the network's hidden layer uses a sigmoid
        activation. Confirm it matches two_layer_net.TwoLayerNet; otherwise the
        prediction will not reflect what was trained.

        Args:
            network: object whose ``params`` dict holds 'W1', 'b1', 'W2', 'b2'.
            x: input vector, presumably a flattened 784-pixel image scaled like
                the training data. TODO: confirm at call sites.

        Returns:
            The predicted class index as an ``int``. It is also printed, as
            before.
        """
        W1, W2 = network.params['W1'], network.params['W2']
        b1, b2 = network.params['b1'], network.params['b2']

        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)

        # Return the label as well as printing it so callers can use the result.
        label = int(np.argmax(y))
        print(label)
        return label
    
    def load_image(file):
        """Load an image file and flatten it into a 784-element pixel vector.

        The image is opened in a ``with`` block so the file handle is closed
        once its pixels have been copied into a NumPy array.

        Args:
            file: path (or file object) accepted by ``PIL.Image.open``.

        Returns:
            A 1-D array of shape (784,) holding the raw (unscaled) pixel values.

        Raises:
            ValueError: if the image does not contain exactly 784 values
                (e.g. it is not 28x28 single-channel).
        """
        with Image.open(file) as im:
            pixels = np.array(im)
        return pixels.reshape(784,)
    
    # The network was trained on load_mnist(normalize=True) data, which
    # presumably scales pixels from 0-255 to [0, 1]. load_image returns raw
    # pixel values, so apply the same scaling before predicting.
    # NOTE(review): if 5.bmp is drawn dark-on-light, it may also need inverting
    # (1.0 - img) to match the training images' polarity. Confirm.
    img = load_image("5.bmp") / 255.0
    print("the image maybe : ")
    predict(network, img)
    

    运行所需的依赖包，以及代码中导入的本地模块（dataset/mnist.py、two_layer_net.py），可从以下百度云链接下载：
    链接: https://pan.baidu.com/s/1xEDyt1-8k3-5nqzOuJLUFg
    提取码: ktpm

    相关文章

      网友评论

          本文标题:mnist手写数字模型训练和识别

          本文链接:https://www.haomeiwen.com/subject/hpeoaqtx.html