美文网首页
K近邻算法

K近邻算法

作者: 伏牛山忍者 | 来源:发表于2019-03-20 17:36 被阅读0次

    算法原理

    首先k紧邻算法是分类算法, 用来判断是属于哪一类别的。
    给若干个有标记的样本, 例如下图:


    图1.jpg

    然后给一个未标记测试样本,通过计算测试样本距离所有样本的距离, 拿出 最小的K 个样本, 判断K个样本中哪个类别最多,就是判定该测试样本属于该类别。


    2.jpg

    优点:算法简单,精度较高,对异常值不敏感, 无数据输入假定
    缺点:对每个数据集的每个数据特征都要计算距离值,计算复杂度高, 空间复杂度高, 所以一般使用小数据集

    算法封装 (python)

    # -*-coding:utf-8-*-
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    def load_data():
        data = np.array([
            [0.1, 0.1, -1],
            [0.2, 0.1, -1],
            [0.3, 0.2, -1],
            [0.2, 0.4, -1],
            [0.3, 0.4, -1],
            [1.3, 1.0, 1],
            [1.2, 0.9, 1],
            [1.1, 1.4, 1],
            [1.3, 1.6, 1],
            [1.5, 1.4, 1],
        ])
        return data
    
    
    def classify(data, test, k=3):
        """
        K紧邻判别方法
        :param data:
        :param test:
        :param k:
        :return:
        """
        distance = np.sum(np.square(data[:, :-1] - test), axis=1)
        k_orders = np.argsort(distance)[:k]
        result = np.sum(data[k_orders, -1])
        if result > 0:
            return 1
        else:
            return -1
    
    
    data = load_data()
    test = [0.8, 0.5]
    rel = classify(data, test, k=5)
    print(rel)
    
    # 可视化
    negative = data[data[:, -1] == -1][:-1]
    positive = data[data[:, -1] == 1][:-1]
    plt.scatter(negative[:, 0], negative[:, 1])
    plt.scatter(positive[:, 0], positive[:, 1])
    plt.scatter(test[0], test[1])
    plt.show()
    
    

    相关文章

      网友评论

          本文标题:K近邻算法

          本文链接:https://www.haomeiwen.com/subject/yoktvqtx.html