今天,写下一篇文章记录神经网络训练的一些细节
话说,我接触神经网络也没有多久,自己用java来实现了一个全连接的神经网络,虽然知道反向传播怎么推导,误差怎么计算,但代码实现起来,又是另外一回事。代码写好了呢,训练起来又是一回事,呵呵
细节1:在手写数字编码MNIST识别的过程中,怎么判断最终的识别数据
识别的时候,输入的数字图片如果是1,那么期望值是
[0,1,0,0,0,0,0,0,0,0],
但实际的结果可能是
[0.09799660858917661,0.10635506730467818,0.10067860599902455,0.9994167927288341,0.0986174993962114,0.9999550189100517,0.9995804178441895,0.9998872790037284,0.09755745682858019,0.09614806053433142]
那么,我们怎么判断神经网络的输出值到底是多少呢,我的处理方法是这样的:输出结果是一个数组,将它进行排序,取最大值,作为神经网络的识别结果。
通俗的解释就是,输入一个手写数组图片,识别结果既像8又像9,但像9更多一些,那就认为它的识别结果是9
细节2:计算识别成功率
因为是自己写的神经网络,为了避免神经网络的过拟合,我在神经网络的持久化信息里面加入了版本信息version,我每训练一次,版本好就会自增1,并在硬盘的某个目录下,新增一个神经网络文件
比如,我训练的神经网络叫"784-16-16-10"(28*28的输入层,两个隐含层,一个输出层),第一次训练后,硬盘上会多出一个"784-16-16-10.0"的文件,第二次会多出一个"784-16-16-10.1"的文件,同时每个批次运行结束后,会做一次0-9个手写数字的识别,并将识别结果+神经网络版本号打印到日志文件中去。
这样,我测试的时候,会将这些神经网络文件依次加载,用来测试,得到每个神经网络版本的识别成功率。哈哈,这样我就不怕过拟合了,而且可以优中选优,《算法导论》中动态规划没有白学喔!
下面再来说一下识别正确率的事情。MNIST提供了60000个训练数据,10000个测试数据。识别正确率 = 识别正确的数据 / 10000(总的测试数据数量)
最后,附上MNIST的下载地址,因为我是用java嘛,所以再附上java解析MNIST的方法。嘿嘿
网友评论