这段时间做了很多的准备工作。
接下来就开始实现机器学习中的‘hello world’
也就是手写数字的识别。
这里给大家介绍一个网站https://www.kaggle.com
在进入这个网站之前,推荐你去看一下https://www.zhihu.com/question/23987009
知乎上有人对这个网站的一些建议。
它上面有很多的训练集和测试集合,而且他们的数据集时不时的更新。
各个用户在上面交流,是学习机器学习的好地方。
手写数字识别数据集是非常著名的数据集。
介绍地址: https://www.kaggle.com/c/digit-recognizer
训练集和测试集下载地址:https://www.kaggle.com/c/digit-recognizer/data
下载之前必须先注册,但是注册会有问题,需要翻墙工具来解决这个问题:
QQ截图20180107121840.jpg
解决之后:
QQ截图20180107122030.jpg
代码块
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author : SundayCoder-俊勇
# @File : digitrecognitionTest.py
# 这段时间做了很多的准备工作。
# 接下来就开始实现机器学习中的‘hello world’
# 也就是手写数字的识别。
# 这里给大家介绍一个网站https://www.kaggle.com
# 在进入这个网站之前,推荐你去看一下https://www.zhihu.com/question/23987009
# 知乎上有人对这个网站的一些建议。
# 它上面有很多的训练集和测试集合,而且他们的数据集时不时的更新。
# 各个用户在上面交流,是学习机器学习的好地方。
# 手写数字识别数据集是非常著名的数据集。
# 介绍地址:
# https://www.kaggle.com/c/digit-recognizer
# 训练集和测试集:https://www.kaggle.com/c/digit-recognizer/data
import csv
from sklearn import neighbors
# (1)引入csv是因为训练集与测试集都是csv文件,所以需要使用
# csv模块进行文件的读写。
# 如果不熟悉csv请参考上一次的教程、弄懂了在过来。
# (2)从sklearn中引用neighbors。neighbors翻译过来是邻近的意思,
# 导入它是因为我们这次的手写数字识别是基于KNN算法,
# KNN算法是临近分类算法中的一种。
# 至于KNN算法是什么?
# 可以去参考一些博客,这个算法的原理到后面我也会讲的。
# 这节课只是为了实现‘hello world’而已。具体的算法原理可以参考后面的教程。
# 接下来需要做的工作:
# (1)导入训练数据和测试数据(使用csv模块)
# (2)使用训练集训练KNN。
# (3)使用测试集测试训练的KNN的正确率。
# (4)把结果写入一个csv文件之中。
# 导入训练数据和测试数据
#导入训练数据和测试数据
def loadData(filename1,filename2,trainDataSet,trainTargetSet,testDataSet):
# 处理训练集csv文件。
with open(filename1,'r') as csvfile1:
csv_reader= csv.reader(csvfile1)
dataSet = list(csv_reader)
# 到这里的功能实现可以参考下面的图1.
for x in range(1,len(dataSet)):
# 不能从0开始是因为0开始的是标题而不是数据。
temp = []
#temp是一个空列表,其作用是保存每一行的所有列的数据。
# 为什么是dataSet[x][0]因为这是一个大列表里面
# 训练集的每一行又是一个小列表,按照嵌套列表的访问规则来获取数据
# 参考下面的图2。
# 注意:训练集的第一列为lable,也就是已经分类好的lable。
# 参考下面的图三。
# dataSet[x][0]得到第一列的lable并将转化为int类型。
dataSet[x][0] = int(dataSet[x][0])
# trainTargetSet是一个空的列表,这里把dataSet[x][0]
# 列表中
trainTargetSet.append(dataSet[x][0])
# 理解了X代表的是行数,第一行也就是x=0代表标题。【不可用】
# 那y呢?y代表的列数由于第一列是lable所以第一列也不可用
# 那785是怎么来的?
# 因为训练集有784列,第一列的数据标题为lable,其他的为pixel+(列数-1)
# range这个函数功能只能取1到784。【比785少一】
# 所以y最大为785
for y in range(1,785):
# dataSet[x][y]取到每一行中的每一列的数据
# 并将其转换为int类型。
dataSet[x][y] = int(dataSet[x][y])
temp.append(dataSet[x][y])
# temp是一个空列表,其作用是保存每一行的所有列的数据。
# 这里for循坏结束之后,temp就保存了x行的所有列的数据。
# 这样在把每一行的数据加到trainDataSet这个列表之中。
trainDataSet.append(temp)
# 接下来按照差不多的方法处理测试集文件csv
with open(filename2,'r') as csvfile2:
csv_reader2 = csv.reader(csvfile2)
dataSet2 = list(csv_reader2)
for x in range(1,len(dataSet2)):
temp = []
# 这里的y从0开始到783,是因为测试集中的第一行不是lable
# 你可以代开测试集来看一下就知道了.
for y in range(784):
dataSet2[x][y] = int(dataSet2[x][y])
temp.append(dataSet2[x][y])
testDataSet.append(temp)
# 所有的数据填充好之后返回
return trainDataSet,trainTargetSet,testDataSet
# 这个函数的功能就是把测试的结果保存在一个csv文件中
def saveResult(result):
#结果保存的路径
with open(r'result.csv','wb') as myFile:
myWriter=csv.writer(myFile)
# 因为机器学习预测完之后返回的结果是一个列表。
# 列表里面有识别的数字。
x=0
# 加入标题
list11=['ImageId','Label']
myWriter.writerow(list11)
for i in result:
x += 1#第一行的行号
tmp=[x]
tmp.append(i)#行号之后是预测的数据
# 写入到文件中
myWriter.writerow(tmp)
def main():
trainDataSet = []
trainTargetSet = []
testDataSet = []
# 这三个一开始均为空的列表。
# 通过loadData函数导入数据到三个列表中去。
print("开始加载数据")
#训练数据和测试数据的路径
loadData(r'train.csv', r'test.csv', trainDataSet, trainTargetSet, testDataSet)
# (1)加载数据完之后得到KNN算法.
knn = neighbors.KNeighborsClassifier()
print("数据加载完毕,开始训练模型")
# (2)使用knn算法训练模型.
knn.fit(trainDataSet,trainTargetSet)
print("模型训练完毕,开始预测")
# (3)预测模型结果。
# 这三步基本大多数的机器学习都遵循着三步步骤。
prediction = knn.predict(testDataSet)
print("预测结果:", prediction)
print("打印完毕,开始保存")
saveResult(prediction)
print("保存完毕")
if __name__ == '__main__':
main()
图一:每一行的数据又是一个小列表,也就是嵌套列表
list.jpg
图二:为什么x从1开始而不是0?
list2.jpg
图三:
train.jpg
图四:
QQ截图20180107122517.jpg
图五:
y.jpg
结果:
跑了之后需要时间才能正确的预测: 跑了之后会有一段时间在加载数据.jpg
结果:
结果.jpg
加油一起学习,更新完毕
网友评论