美文网首页python自学
神经网络实现手写数字识别(吴恩达课程Octave代码用pytho

神经网络实现手写数字识别(吴恩达课程Octave代码用pytho

作者: Damon0626 | 来源:发表于2018-12-06 23:03 被阅读2次

详细代码参考github

神经网络实现手写数字识别

实例:利用神经网络实现手写数字的识别,网络已经训练好,权重参数已给出。

1.载入数据和权重

由于给出了权重的参数,即神经网络已经训练好了,我们直接拿权重对现有的图片进行预测。

载入输入数据:
参考代码:

def loadData(self, path):
       self.data = scio.loadmat(path)
       self.x = self.data["X"]  # (5000, 400)  # 原100训练
       self.y = self.data["y"]  # (5000, 1)
       index = random.sample([i for i in range(5000)], 100)  # 随机100个没有重复的数字
       self.pics = self.x[index, :]  # (100, 400)

载入权重参数
参考代码:

def loadWeights(self, path):
       weights = scio.loadmat(path)
       self.theta1 = weights['Theta1']  # 25*401
       self.theta2 = weights['Theta2']  # 10*26
2.神经网络构建

神经网络共3层,输入层,1层隐藏层,输出层:输入层401个输入(第1个为1), 隐藏层26个单元,输出层10个单元(对应着0-9),如下图

在这里插入图片描述
3.对全部数据进行准确率验证

利用吴老师训练好的结果,进行验证,准确率97.52%。

参考代码:

def predictNN(self):
       x = np.hstack([np.ones((self.x.shape[0], 1)), self.x])  # 5000*401
       x1 = self.sigmoid(x.dot(self.theta1.T))  # (5000, 401)*(401, 25)

       x1_mid = np.hstack([np.ones((x1.shape[0], 1)), x1])
       x2 = self.sigmoid(x1_mid.dot(self.theta2.T))  # (5000, 26)*(26, 10)
       position = np.argmax(x2, axis=1) + 1   # 预测值
       accuracy = np.mean(position.reshape(5000, 1) == self.y) * 100
       print("神经网络准确率是:{}".format(accuracy))  # 97.52%
4.随机抽出一张图片,对图片中的数字进行验证

从5000张图片中随机抽取一张,利用神经网络计算得到预测结果。将预测结果做成图片的title进行显示,关闭图片后,提示若继续验证,请按回车;退出请按q键,展示两次的预测结果,可以看到已准确识别!

预测7 预测5
参考代码:
def predictOne(self, image):
       x = np.hstack([np.ones((image.shape[0], 1)), image])  # 1*401
       x1 = self.sigmoid(x.dot(self.theta1.T))  # (1, 401)*(401, 25)

       x1_mid = np.hstack([np.ones((x1.shape[0], 1)), x1])
       x2 = self.sigmoid(x1_mid.dot(self.theta2.T))  # (1, 26)*(26, 10)

       position = np.argmax(x2, axis=1) + 1
       return position
       
def displayTestPics(self, image):
       max_val = np.max(np.abs(image))
       im = image.reshape((20, 20)).transpose()/max_val*255
       predict_result = self.predictOne(image)
       plt.xticks([])
       plt.yticks([])
       # 由于0用10表示,为了显示准确,取了余数.
       plt.title("The Prediction Result is {}!".format(np.mod(predict_result[0], 10)), color='r', fontsize=20)
       plt.imshow(im, cmap='gray')
       plt.show()

相关文章

网友评论

    本文标题:神经网络实现手写数字识别(吴恩达课程Octave代码用pytho

    本文链接:https://www.haomeiwen.com/subject/jiaacqtx.html