我:终于到周末了,可以休息一下了!!!来几把LOL!!!
(叮铃......叮铃......叮铃......)
我:喂,老板啊?怎么啦
老板:小韩啊,在家休息吗?
我:是啊。
老板:别休息啦,来加个班,用上次你写的kNN,做一个手写识别系统,训练集和测试集我都发你邮箱了!周日晚上给我!
我:(What???大周末的,你让我加班,老子不干了!)行,保证写出来!
行了行了,周末不休息了,开工!
这次我们要构建一个手写识别系统,为了简单,我们就只识别0-9。需要识别的数字已经用图形处理软件,处理成具有相同的色彩和大小:宽高是32像素×32像素的黑白图像。尽管采用文本格式存储图像不能有效地利用内存空间,但是为了方便我们的理解,我们还是将图像转换为文本格式。示例如下:
image然后,我们来看一下,使用kNN构造手写识别系统的步骤:
- 收集数据:提供文本文件。
- 准备数据:编写函数classify0(),将图像格式转换为分类器使用的list格式。
- 分析数据:在Python命令提示符中检查数据,确保它符合要求。
- 训练算法:此步骤不适用于k-近邻算法。
- 测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
- 使用算法:本例没有完成此步骤,若你感兴趣可以构建完整的应用程序,从图像中提取数字,并完成数字识别,美国的邮件分拣系统就是一个实际运行的类似系统。
2.3.1 准备数据:将图像转换为测试向量
老板给的训练集在目录trainingDigits中,其中包含了大约2000个例子,每个数字大概有200个样本。测试集在目录testDigits中,其中大约900个测试数据。截图如下:
image image每个文本文件名称下划线前的数字代表这个文本文件所代表数字。比如说0_8.txt代表的是数字0的第9个样本(从0开始计数)。
为了使用我们先前编写好的分类器,我们必须将图像格式化处理为一个向量。我们将一个32×32的二进制图像矩阵转换为1×1024的向量。
好了,代码走起来!我们继续在kNN.py中编写函数img2vector,代码如下:
def img2vector(filename):
returnVect = zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect
代码很简单,就是将原来32×32转换成1×1024,这里我也就不多说什么了。大家可以自己去测试一下效果。
2.3.2 使用k-近邻算法识别手写数字
上一节我们已经把数据处理成我们想要的格式了,那么接下来我们就可以将这些数据丢到分类器里了。直接来看代码:
def handwritingClassTest():
# 1.初始化我们所需要的数据
hwLabels = []
trainingFileList = os.listdir('trainingDigits') # 这里需要我们提前导入os模块,listdir可以列出给定目录下的文件名
m = len(trainingFileList) # 获得训练样本数目
trainingMat = zeros((m, 1024)) # 构造m×1024的矩阵
# 2.循环遍历训练集中的每个文件,生成每个数字的向量信息,保存在trainingMat中
for i in range(m):
fileNameStr = trainingFileList[i] # 获得文件名
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0]) # 获得该文件所代表的数字
hwLabels.append(classNumStr) # 将文件所代表的数字其存放在类别标签中
trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr) # 数据转换
# 3.遍历测试数据文件夹,使用kNN进行测试。
testFileList = os.listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList) # 获得测试样本数目
for i in range(mTest):
fileNameStr = testFileList[i] # 获得文件名
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0]) # 获得该文件所代表的数字
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) # 分类
print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
if classifierResult != classNumStr:
errorCount += 1.0
print('\nthe total number of errors is: %d' % errorCount)
print('\nthe total error rate is: %f' % (errorCount / float(mTest)))
上面代码也不难,每一步的具体含义我都给大家写在注释中了,所以我也就不多说了。
依赖于机器速度,加载数据集可能要花费很长时间,然后函数开始依次测试每个文件,我们直接来看输出的结果:
image我们使用k-近邻算法识别手写数字数据集,错误率为1.2%。
改变变量k的值、修改函数handwritingClassTest随机选取训练样本、改变训练样本的数目,都会对k-近邻算法的错误率产生影响,感兴趣的话可以改变这些变量值,观察错误率的变化。
但是,我们需要注意的是,实际使用这个算法时,算法的执行效率并不高。原因如下:
- 算法需要为每个测试向量做2000次距离计算,每个距离计算包括了1024个维度浮点运算,总计要执行900次,
- 此外,我们还需要为测试向量准备2MB的存储空间。
2.4 小结
kNN的理论、实战,我们就讲到这里了,下面我们来总结一下:
- k-近邻算法是分类数据最简单最有效的算法,我们通过两次实战讲述了如何使用k-近邻算法构造分类器。
- k-近邻算法是基于实例的学习,使用算法时我们必须有接近实际数据的训练样本数据。
- k-近邻算法必须保存全部数据集,如果训练数据集的很大,必须使用大量的存储空间。此外, 由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时。
- k-近邻算法的另一个缺陷是它无法给出任何数据的基础结构信息,因此我们也无法知晓平均实例样本和典型实例样本具有什么特征。
好了,k-近邻算法我们就讲到这里,因为是最基础的,所以用了比较多的篇幅,希望大家能够慢慢看完,对机器学习先有一个感性的认识。
机器学习的路还很长,加油,冲冲冲!!!
最后,还是熟悉的配方!
欢迎大家关注我的公众号,有什么问题也可以给我留言哦!
image
网友评论