美文网首页机器学习机器学习
机器学习之python之手写数字识别,超详细的教程(二十一)

机器学习之python之手写数字识别,超详细的教程(二十一)

作者: SundayCoder | 来源:发表于2018-01-07 14:19 被阅读0次

这段时间做了很多的准备工作。
接下来就开始实现机器学习中的‘hello world’
也就是手写数字的识别。
这里给大家介绍一个网站https://www.kaggle.com

在进入这个网站之前,推荐你去看一下https://www.zhihu.com/question/23987009
知乎上有人对这个网站的一些建议。

它上面有很多的训练集和测试集合,而且他们的数据集时不时的更新。
各个用户在上面交流,是学习机器学习的好地方。
手写数字识别数据集是非常著名的数据集。
介绍地址: https://www.kaggle.com/c/digit-recognizer
训练集和测试集下载地址:https://www.kaggle.com/c/digit-recognizer/data

下载之前必须先注册,但是注册会有问题,需要翻墙工具来解决这个问题:


QQ截图20180107121840.jpg

解决之后:


QQ截图20180107122030.jpg

代码块


#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author  : SundayCoder-俊勇
# @File    : digitrecognitionTest.py

# 这段时间做了很多的准备工作。
# 接下来就开始实现机器学习中的‘hello world’
# 也就是手写数字的识别。
# 这里给大家介绍一个网站https://www.kaggle.com

# 在进入这个网站之前,推荐你去看一下https://www.zhihu.com/question/23987009
# 知乎上有人对这个网站的一些建议。

# 它上面有很多的训练集和测试集合,而且他们的数据集时不时的更新。
# 各个用户在上面交流,是学习机器学习的好地方。
# 手写数字识别数据集是非常著名的数据集。
# 介绍地址:
# https://www.kaggle.com/c/digit-recognizer
# 训练集和测试集:https://www.kaggle.com/c/digit-recognizer/data

import csv
from sklearn import neighbors
# (1)引入csv是因为训练集与测试集都是csv文件,所以需要使用
# csv模块进行文件的读写。
# 如果不熟悉csv请参考上一次的教程、弄懂了在过来。

# (2)从sklearn中引用neighbors。neighbors翻译过来是邻近的意思,
# 导入它是因为我们这次的手写数字识别是基于KNN算法,
# KNN算法是临近分类算法中的一种。
# 至于KNN算法是什么?
# 可以去参考一些博客,这个算法的原理到后面我也会讲的。
# 这节课只是为了实现‘hello world’而已。具体的算法原理可以参考后面的教程。

# 接下来需要做的工作:
# (1)导入训练数据和测试数据(使用csv模块)
# (2)使用训练集训练KNN。
# (3)使用测试集测试训练的KNN的正确率。
# (4)把结果写入一个csv文件之中。

# 导入训练数据和测试数据
#导入训练数据和测试数据
def loadData(filename1,filename2,trainDataSet,trainTargetSet,testDataSet):
    # 处理训练集csv文件。
    with open(filename1,'r') as csvfile1:
        csv_reader= csv.reader(csvfile1)
        dataSet = list(csv_reader)
        # 到这里的功能实现可以参考下面的图1.

        for x in range(1,len(dataSet)):
            # 不能从0开始是因为0开始的是标题而不是数据。
            temp = []
            #temp是一个空列表,其作用是保存每一行的所有列的数据。

            # 为什么是dataSet[x][0]因为这是一个大列表里面
            # 训练集的每一行又是一个小列表,按照嵌套列表的访问规则来获取数据
            # 参考下面的图2。

            # 注意:训练集的第一列为lable,也就是已经分类好的lable。
            # 参考下面的图三。

            # dataSet[x][0]得到第一列的lable并将转化为int类型。
            dataSet[x][0] = int(dataSet[x][0])
            # trainTargetSet是一个空的列表,这里把dataSet[x][0]
            # 列表中
            trainTargetSet.append(dataSet[x][0])
            # 理解了X代表的是行数,第一行也就是x=0代表标题。【不可用】
            # 那y呢?y代表的列数由于第一列是lable所以第一列也不可用
            # 那785是怎么来的?
            # 因为训练集有784列,第一列的数据标题为lable,其他的为pixel+(列数-1)
            # range这个函数功能只能取1到784。【比785少一】
            # 所以y最大为785
            for y in range(1,785):
                # dataSet[x][y]取到每一行中的每一列的数据
                # 并将其转换为int类型。
                dataSet[x][y] = int(dataSet[x][y])
                temp.append(dataSet[x][y])
                # temp是一个空列表,其作用是保存每一行的所有列的数据。
            # 这里for循坏结束之后,temp就保存了x行的所有列的数据。
            # 这样在把每一行的数据加到trainDataSet这个列表之中。
            trainDataSet.append(temp)
    # 接下来按照差不多的方法处理测试集文件csv
    with open(filename2,'r') as csvfile2:
        csv_reader2 = csv.reader(csvfile2)
        dataSet2 = list(csv_reader2)
        for x in range(1,len(dataSet2)):
            temp = []
            # 这里的y从0开始到783,是因为测试集中的第一行不是lable
            # 你可以代开测试集来看一下就知道了.
            for y in range(784):
                dataSet2[x][y] = int(dataSet2[x][y])
                temp.append(dataSet2[x][y])
            testDataSet.append(temp)
    # 所有的数据填充好之后返回
    return trainDataSet,trainTargetSet,testDataSet

# 这个函数的功能就是把测试的结果保存在一个csv文件中
def saveResult(result):
#结果保存的路径
    with open(r'result.csv','wb') as myFile:
        myWriter=csv.writer(myFile)
        # 因为机器学习预测完之后返回的结果是一个列表。
        # 列表里面有识别的数字。
        x=0
        # 加入标题
        list11=['ImageId','Label']
        myWriter.writerow(list11)
        for i in result:
            x += 1#第一行的行号
            tmp=[x]
            tmp.append(i)#行号之后是预测的数据
            # 写入到文件中
            myWriter.writerow(tmp)
def main():
    trainDataSet = []
    trainTargetSet = []
    testDataSet = []
    # 这三个一开始均为空的列表。
    # 通过loadData函数导入数据到三个列表中去。
    print("开始加载数据")
    #训练数据和测试数据的路径
    loadData(r'train.csv', r'test.csv', trainDataSet, trainTargetSet, testDataSet)
    # (1)加载数据完之后得到KNN算法.
    knn = neighbors.KNeighborsClassifier()
    print("数据加载完毕,开始训练模型")
    # (2)使用knn算法训练模型.
    knn.fit(trainDataSet,trainTargetSet)
    print("模型训练完毕,开始预测")
    # (3)预测模型结果。
    # 这三步基本大多数的机器学习都遵循着三步步骤。
    prediction = knn.predict(testDataSet)
    print("预测结果:", prediction)
    print("打印完毕,开始保存")
    saveResult(prediction)
    print("保存完毕")
if __name__ == '__main__':
    main()

图一:每一行的数据又是一个小列表,也就是嵌套列表


list.jpg

图二:为什么x从1开始而不是0?


list2.jpg
图三:
train.jpg

图四:


QQ截图20180107122517.jpg
图五:
y.jpg
结果:
跑了之后需要时间才能正确的预测: 跑了之后会有一段时间在加载数据.jpg

结果:


结果.jpg

加油一起学习,更新完毕

相关文章

网友评论

    本文标题:机器学习之python之手写数字识别,超详细的教程(二十一)

    本文链接:https://www.haomeiwen.com/subject/dqwpnxtx.html