美文网首页
kNN(k-NearestNeighbor)算法

kNN(k-NearestNeighbor)算法

作者: 统计学徒 | 来源:发表于2018-12-22 00:05 被阅读0次

    参考《数据挖掘10大算法》对kNN算法进行基本总结,附有一个Python3的简例。

    基本思想
    从训练集中找出 k 个最接近测试对象的训练对象,再从这 k 个对象中找出居于主导的类别,将其赋给测试对象。

    定位
    由于这种总体占优的决策模式,对于类域的交叉、重叠较多的或者多模型、多标签的待分样本集来说,kNN方法较其他方法更为适合。kNN算法属于有监督学习的分类算法。

    避开了两个问题
    (1)分类时对象之间不可能完全匹配(kNN方法计算的是对象之间的距离);
    (2)具有相同属性的对象有不同的类别(kNN方法依据总体占优的类别进行决策,而不是单一对象的类别进行决策)。

    需要考虑几个关键要素
    (1)训练集;
    (2)用于计算对象之间临近的程度或者其他相似的指标;
    (3)最近邻的个数 k;
    (4)基于 k 个最近邻及其类别对目标对象类别进行判定的方法。

    kNN方法很容易理解和实现,在一定条件下,其分类错误率不会超过最优贝叶斯错误率的两倍。一般情况下,kNN方法的错误率会逐渐收敛到最优贝叶斯错误率,可以用作后者的近似。

    基本算法

    image.png

    算法的存储复杂度为O(n),时间复杂度为O(n),其中 n 为训练对象的数量。

    影响kNN算法性能的几个关键因素
    (1)k 值的选择;
    如果 k 值选得过小,结果就会对噪声点特别敏感;k 值选得过大就会使得近邻中包含太多别的类的点。最佳 k 值的估计可以使用交叉验证的方法。通常,使用 k=1会有一个比较好的结果(特别是对于小数据集的情况)。但是,在样本很充足的情况下,选择较大的 k 值可以提高抗噪能力。

    (2)类别决策时的综合方法;
    对目标对象的类别进行决策,最简单的就是使用总体占优方法(简单投票,票数最多的一类胜出)。稍微复杂一点,考虑近邻中每个点与目标对象的距离不同,对决策的份量进行加权考虑。


    image.png
    image.png

    (3)距离测量标准的选择。
    距离测量的标准一般选择欧几里得距离或者曼哈顿距离

    image.png
    image.png

    简单例子

    '''#########################################################################################
    # Name: kNN-test
    # Author: Wenchao Liu
    # Date: 2018-12-22
    # Description: To study the kNN method by using a simple example.
    #                    Windows10, Python3.7
    #########################################################################################'''
    
    from math import sqrt
    
    def calculateDistance(D, z):
        '''计算目标点与训练集中的点的距离,使用欧几里得距离。
          输入(D:训练集数据;z:目标向量)。
          输出(distance:目标点与训练样本点的欧几里得距离)。'''
        distance = []
        for item in D:
            d = 0
            for i in range(len(item)):
                d =  d + (z[i] - item[i])**2
            distance.append(sqrt(d))
        #print(distance)
        return distance
    
    def selectClass(k, L, distance):
        '''类别决策,判定目标属于哪一个类别。
          输入(k:取最近邻点的数量;L:类别标签;distance:距离)。'''
        if(k==1):
            min_index = distance.index(min(distance))
            #print(min_index)
            print('上证指数今天收盘结果是:' + L[min_index])
    
    def main():
        # 训练集。样本数量6;属性4:交点、前一天涨跌(0跌1涨)、振幅(%)、高低开(0低1高)
        D = [[0, 1, 1.33, 1], [7, 1, 0.55, 0], [0, 1, 1.29, 0], [0, 1, 0.75, 0], [0, 0, 0.43, 1], [0, 1, 0.52, 1]]
        # 训练样本的类别标签集
        L = ['涨', '涨', '跌', '跌', '涨', '涨']
        # 测试目标
        z = [0, 0, 1.13, 0]
        k = 1    # 选取最近邻点的数量
        distance = calculateDistance(D, z)
        selectClass(k, L, distance)
    
    if __name__ == '__main__':
        main()
    

    [1] 《数据挖掘10大经典算法》

    相关文章

      网友评论

          本文标题:kNN(k-NearestNeighbor)算法

          本文链接:https://www.haomeiwen.com/subject/eapbkqtx.html