01. The KNN Algorithm

Author: ltjxwxz | Published: 2020-02-03 21:36

1. A hand-written KNN test case in Python

```python
import matplotlib.pyplot as plt
import numpy as np
from math import sqrt
from collections import Counter

# Raw data: ten 2-D sample points and their class labels (0 or 1)
raw_data_x = [[3.393533211, 2.331273381],
    [3.110073483, 1.781539638],
    [1.343808831, 3.368360954],
    [3.582294042, 4.679179110],
    [2.280362439, 2.866990263],
    [7.423436942, 4.696522875],
    [5.745051997, 3.533989803],
    [9.172168622, 2.511101045],
    [7.792783481, 3.424088941],
    [7.939820817, 0.791637231]]
raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
# The point to classify
target = [8.093607318, 3.365731514]

# Convert to NumPy arrays
X_train = np.array(raw_data_x)
y_train = np.array(raw_data_y)

# Plot the data: class 0 in red, class 1 in blue, the target point in yellow
plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], color="r")
plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], color="b")
plt.scatter(target[0], target[1], color="y")
plt.show()

def KNN_classify(k, X_train, y_train, x):
    # Euclidean distance from x to every training point
    distances = [sqrt(np.sum((x_train - x)**2)) for x_train in X_train]
    # Indices of the k training points closest to x
    nearest_k = np.argsort(distances)[:k]
    topk_y = y_train[nearest_k]
    # Count how often each label occurs among the k nearest neighbors
    votes = Counter(topk_y)
    # most_common(n) returns the n labels with the most votes
    return votes.most_common(1)[0][0]

predict_y = KNN_classify(6, X_train, y_train, target)
print(predict_y)
```
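The listing above prints only the winning label. To see how the vote actually splits, the same procedure can be sketched with the standard library alone. This is a minimal illustration, not the author's code: `knn_vote_breakdown` is a hypothetical helper name, the `k` range check is an added assumption, and `math.dist` requires Python 3.8 or later.

```python
from collections import Counter
from math import dist  # Euclidean distance between two points (Python 3.8+)


def knn_vote_breakdown(k, points, labels, query):
    """Hypothetical helper: return (indices of the k nearest points, label votes)."""
    if len(points) != len(labels):
        raise ValueError("points and labels must have the same length")
    if not 1 <= k <= len(points):
        raise ValueError("k must be between 1 and the number of training points")
    # Order the training indices by their distance to the query point.
    order = sorted(range(len(points)), key=lambda i: dist(points[i], query))
    nearest = order[:k]
    # Count the labels of the k nearest points.
    votes = Counter(labels[i] for i in nearest)
    return nearest, votes


# Same data as in the article.
raw_data_x = [[3.393533211, 2.331273381],
              [3.110073483, 1.781539638],
              [1.343808831, 3.368360954],
              [3.582294042, 4.679179110],
              [2.280362439, 2.866990263],
              [7.423436942, 4.696522875],
              [5.745051997, 3.533989803],
              [9.172168622, 2.511101045],
              [7.792783481, 3.424088941],
              [7.939820817, 0.791637231]]
raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
target = [8.093607318, 3.365731514]

nearest, votes = knn_vote_breakdown(6, raw_data_x, raw_data_y, target)
print(nearest)              # [8, 7, 5, 6, 9, 3]
print(votes.most_common())  # [(1, 5), (0, 1)]
```

With k = 6, five of the six nearest neighbors belong to class 1 and one belongs to class 0, so the majority vote is class 1.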

Article link: https://www.haomeiwen.com/subject/sgokxhtx.html