美文网首页
K近邻算法原理及实战

K近邻算法原理及实战

作者: takumiCX | 来源:发表于2018-06-23 22:18 被阅读0次

    1.KNN算法原理

    KNN(k-nearest neighbor,k-NN),即K近邻算法,可以说是最简单直观的分类算法之一了。说它简单是因为K-近邻算法没有显式的学习过程,不像其他分类算法需要训练出分类模型;说它直观是因为它的原理极其容易被理解,容易让人联想到“近朱者赤,近墨者黑”这句俗语。它的原理可以用一句话概括:对于给定的输入实例x, 从训练数据集{(x1,y1),(x2,y2).....(xn,yn)}中找到和该输入实例最相似的K个实例,在这K个实例中按某种分类决策规则决定x所属的类别y。当然上面这句话还有几处容易让人产生疑问的地方:1.相似度是如何度量的?2.K值是如何选取的?3.分类决策规则是怎样的?

    • 相似度如何度量

    K近邻算法中是根据实例点之间的距离来度量它们之间的相似度的。通常使用欧几里得距离,但也可以是曼哈顿距离等其他距离,以特征空间为n为实数空间Rn为例。

    • 欧几里得距离的计算公式为:


      欧氏距离.png
    • 曼哈顿距离的计算公式为:


      曼哈顿距离.png
    • K值的选择

    较小的K值会导致模型对近邻的实例点非常敏感,增加噪声对分类结果的影响。k值得减小意味着整体模型变得复杂,容易发生过拟合。

    较大的K值可以减少学习的估计误差,但是会增大近似误差,容易发生预测错误。

    一般可以采用交叉验证法来选取最优的K值。

    • 分类决策规则

    常用的分类决策规则是多数表决规则,它指的是根据被选出的K个实例的类别中的多数决定输入样例所属的类别。

    2.KNN算法实战

    代码实现基于python3.6

    2.1.构造训练样本集

    
    from numpy import *
    import operator
    
    # 创建训练样本集和对应的标签向量 def createDataSet():
        group = array([[1.0, 1.1], [1.0, 1.0], [0, 1], [0, 0.1]])
        labels = ['A', 'A', 'B', 'B']
        return group, labels
    
    
    

    该训练样本集由两部分组成,一部分为group表示的特征向量矩阵,矩阵的每一行表示一个训练样本的特征向量,这里每个特征向量含有2个特征,group里共有4个样本。另一部分labels为类别标签向量,表示对应特征向量所属的类别。

    2.2.KNN分类算法实现

    
    # inX:输入向量,dataSet:训练样本集矩阵,labels:训练样本标签向量,k:分类k值   def classify0(inX, dataSet, labels, k):
      dataSetSize = dataSet.shape[0]  # 训练样本个数
      diffMat = tile(inX, (dataSetSize, 1)) - dataSet  # 输入向量与训练集矩阵特征差构成的矩阵
      sqDist = (diffMat ** 2).sum(axis=1)  # 矩阵以行为单位求和
      dist = sqDist ** 0.5 # 得到输入向量与各训练样本的欧式距离
      sortedIndicies = dist.argsort()  # 按距离从小到大排序,返回排序后对应的原始dist列表的索引值
      classCount = {}
      for i in range(k):
            label = labels[sortedIndicies[i]]  # 获得类别标签
      classCount[label] = classCount.get(label, 0) + 1 # 该类别个数+1
      sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # 按类别个数降序排序
      return sortedClassCount[0][0]  # Set(Tuple) 返回类别 
    
    

    代码实现过程可以描述如下:

    1.计算训练集dataSet中的样本点(特征向量)与输入实例(inX)之间的欧几里得距离。

    2.选取距离最小的k个点。

    3.分别累计k个点表示的样本类别的出现次数。

    4.选取出现次数最多的类别为输入实例inX的预测类别。

    2.3.测试分类效果

    group, labels = createDataSet()
    print(classify0([0, 0], group, labels, 3))
    

    以输入向量[0,0]作测试,得到结果为类别B

    2.4 KNN预测约会对象魅力值

    如何预测约会对象的魅力值,可以从以下三个方面考虑(为什么是这几个方面?)
    1.每年获得的飞行常客里程数
    2.玩视频游戏所耗时间百分比
    3.每年消耗的冰激凌公升数

    而预测结果可以分为三个类别:不喜欢,魅力一般,极具魅力。

    # 输入为文件名字符串,输出为训练样本的特征矩阵和类别标签向量
    def file2matrix(str):
        fr= open(str)
        arrayOLines = fr.readlines()
        numberOfLines = len(arrayOLines)  # 训练样本个数
        returnMat = zeros((numberOfLines, 3))  # 训练样本个数X特征个数的特征矩阵,初始化为0
        classLabelVector = []  # 类别标签向量
        index = 0
        for line in arrayOLines:
            line = line.strip()  # 截掉所有的回车字符
            listFromLine = line.split('\t')
            returnMat[index, :] = listFromLine[0:3]  # 取每行的前三列赋值给特征矩阵的对应行
            classLabelVector.append(int(listFromLine[-1]))  # 取每行的最后一列赋值给类别标签向量
            index += 1
    
        return returnMat, classLabelVector
    
    
    datingMat, datingLabels = file2matrix('/Users/chengxu/PycharmProjects/ml-practice/com/chengxu/knn/datingTestSet2.txt')
    # print(datingMat)
    # print(datingLabels)
    
    # 绘图
    import matplotlib
    import matplotlib.pyplot as plt
    
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']  # 解决中文无法显示的问题
    fig = plt.figure()  # 创建画板(figure)
    ax = fig.add_subplot(111)  # 指定子视图的位置
    x1 = []
    x2 = []
    x3 = []
    y1 = []
    y2 = []
    y3 = []
    for i in range(datingMat.shape[0]):
        if datingLabels[i] == 1:
            x1.append(datingMat[i, 1])
            y1.append(datingMat[i, 2])
        elif datingLabels[i] == 2:
            x2.append(datingMat[i, 1])
            y2.append(datingMat[i, 2])
        else:
            x3.append(datingMat[i, 1])
            y3.append(datingMat[i, 2])
    
    t1 = ax.scatter(x1, y1, c='red')
    t2 = ax.scatter(x2, y2, c='blue')
    t3 = ax.scatter(x3, y3, c='green')
    ax.axis([0, 25, 0.0, 2.0])
    ax.legend([t1, t2, t3], ["不喜欢", "魅力一般", "极具魅力"], loc=2)
    plt.xlabel('打游戏所占时间百分比')
    plt.ylabel('每周消耗冰淇淋公升数')
    # ax.scatter(datingMat[:, 0], datingMat[:, 1], 15.0 * array(datingLabels), 15.0 * array(datingLabels))  # 取第2,第3列数据
    plt.show()
    
    
    # 数据归一化   (v-min)/(max-min)
    def autoNorm(dataSet):
        minValues = dataSet.min(0)  # 取每列的最小值
        maxValues = dataSet.max(0)  # 取每列的最大值
        ranges = maxValues - minValues  # 每列最大值与最小值之差
        normDataSet = zeros(shape(dataSet))  # 待归一化的新特征矩阵
        m = dataSet.shape[0]  # 样本总数
        normDataSet = dataSet - tile(minValues, (m, 1))  # 减去每列的最小值
        normDataSet = normDataSet / tile(ranges, (m, 1))  # 处以最大值与最小值之差
        return normDataSet, ranges, minValues
    
    
    # normDatingMat = autoNorm(datingMat)
    # print(normDatingMat)
    
    def datingClassTest():
        hoRatio = 0.10
        datingMat, datingLabels = file2matrix(
            '/Users/chengxu/PycharmProjects/ml-practice/com/chengxu/knn/datingTestSet2.txt')
        normMat, ranges, minValues = autoNorm(datingMat)
        m = normMat.shape[0]
        numTest = int(m * hoRatio)
        errCount = 0.0
        for i in range(numTest):
            classifyRes = classify0(normMat[i, :], normMat[numTest:m, :], datingLabels[numTest:m], 3)
            print("预测值:%d ,真实值:%d" % (classifyRes, datingLabels[i]))
            if (classifyRes != datingLabels[i]):
                errCount += 1.0
        print('错误率为%.2f'%(errCount/float(numTest)))
    
    
    # datingClassTest()
    
    
    def classifyPerson():
        resultList=['不喜欢','魅力一般','极具魅力']
        persontTats=float(input('玩游戏时间占比?'))
        ffMiles=float(input('每年飞行里程数?'))
        iceCream=float(input('每年消费的冰激淋公升数?'))
        datingMat,datingLabels=file2matrix('/Users/chengxu/PycharmProjects/ml-practice/com/chengxu/knn/datingTestSet2.txt')
        normMat,ranges,minVals=autoNorm(datingMat)
        inArr=array([ffMiles,persontTats,iceCream])
        classifyResult=classify0((inArr-minVals)/ranges,normMat,datingLabels,3)
        print("%s"%resultList[classifyResult-1])
    
    
    classifyPerson()
    

    参考资料

    1. 统计学习方法
    2. 机器学习实战

    相关文章

      网友评论

          本文标题:K近邻算法原理及实战

          本文链接:https://www.haomeiwen.com/subject/feidyftx.html