教材选用《统计学习方法》,第一版,李航著;
代码取自《机器学习实战》,人民邮电出版社;
K-近邻算法(K-NearestNeighbours)是一种最简单的基本分类方法,它的定义为:在给定训练集中,找到与新的输入实例最相近的K个实例,新实例的类型就被划分为这K个实例多数属于的类型。
K-近邻算法的模型实际对应特征空间的划分,模型由距离度量、K值的选择和分类决策规则所决定。
距离度量一般使用欧氏距离,不同的距离度量所确定的最近邻点是不同的。距离度量相关的资料点这里。
K值的选择会对模型结果产生重大影响,在应用中K一般选一个较小的值,再采用交叉验证法找出最优的K值。
分类决策规则一般选用多数表决,多数表决规则等价于经验风险最小化。
K-近邻算法最简单的实现方法是线性扫描,但考虑到效率问题采用kd树实现可以优化速度,kd树更适用于训练实例远大于空间维数时的K-近邻搜索。
![](https://img.haomeiwen.com/i4964701/1d607fd44f5aa42a.png)
![](https://img.haomeiwen.com/i4964701/ce1af4fdcff2a8f5.png)
以下为K-近邻算法的简单实现例子,首先将以下4个数据点分为蓝和红两类,然后通过K-近邻算法找出新输入的点属于哪一类。
1.使用Python导入数据
首先导入科学计算包numpy和operator模块,然后建立数据集,样本为4个点,前两个设为红点,后两个设为蓝点。(运行环境为Python3.6)
from numpy import *
import operator
def createDataSet():
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['r','r','b','b']
return group, labels
为了更清晰地看出数据之间的关系,通过数据可视化包matplotlib绘制出散点图。散点图的绘制教程点这里。
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(group[:,0],group[:,1],15,labels)
plt.show()
![](https://img.haomeiwen.com/i4964701/1f9183f51b041614.png)
2.K-近邻算法的实现
接下来定义分类算法classify0(),下面是Python代码,然后会详细解释每行代码的含义。
from os import listdir
def classify0(inX, dataSet, labels, k):
#1 计算欧氏距离,见下方计算公式
dataSetSize = dataSet.shape[0] #获取数据形状
diffMat = tile(inX, (dataSetSize,1)) - dataSet #用tile()重复计算新输入分类点和样本点之前的差值数组
sqDiffMat = diffMat**2 #将差值数组平方
sqDistances = sqDiffMat.sum(axis=1) #将平方后的数组累加
distances = sqDistances**0.5 #开方得欧氏距离distances
sortedDistIndicies = distances.argsort() #argsort()按从小到大顺序排序所有点距离
#2 得到主要分类
classCount={} #新建classCount数组
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]] #取出数组标签
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #按前k个投票给点,得到主要标签
#3 返回频率最高的分类
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
classify0有四个输入:inX为新输入的分类点,dataSet为样本训练集(4个数据点),labels为分类标签(蓝和红),k为选择近邻的参数。
1.按照公式计算新输入点和所有样本点之间的欧氏距离;
![](https://img.haomeiwen.com/i4964701/b10ad8336b1cd926.png)
2.按照从小到大的次序进行排序,确定前k个距离最小元素的主要分类;
3.将classcount分解为元组列表,按照第二个元素从大到小排序,返回频率最高的元素标签。
3.测试分类器classify0
为了测试数据所有分类,在命令行中输入如下命令:
print(classify0([0.8,0.8], group, labels, 3)) #测试用了(0.8,0.8)这个点,k值用了3
在调试模式下运行命令,查看变量的属性。dataSetSize 为4,代表样本4个点;diffMat 显示了前两个点的差值( -0.2 = 0.8 - 1)。
![](https://img.haomeiwen.com/i4964701/75856a9ea72f474c.png)
sqDiffMat 为 diffMat数组的平方; sqDiffMat 为 sqDiffMat 数组中的平方值相加;distances 为sqDiffMat 数组的开方;sortedDistIndicies 为distances 进行了排序,并取索引为新数组。
![](https://img.haomeiwen.com/i4964701/fcb81f966490cf9e.png)
![](https://img.haomeiwen.com/i4964701/d90632ea0cafeee3.png)
![](https://img.haomeiwen.com/i4964701/4bf32e745e708bfb.png)
classcount统计出样本空间里前k=3个距离最近的点,红色的为2个,蓝色的1个;sortedclasscount为排序后的标签数组,因此预测新输入的点属于红色类别。
![](https://img.haomeiwen.com/i4964701/087f911970c897df.png)
![](https://img.haomeiwen.com/i4964701/3a2dbe0d007e360e.png)
![](https://img.haomeiwen.com/i4964701/bd85ef4839f71911.png)
![](https://img.haomeiwen.com/i4964701/906830ffbf46a464.png)
网友评论