美文网首页
10行代码实现kNN算法

10行代码实现kNN算法

作者: uglybachelor | 来源:发表于2018-09-26 23:41 被阅读20次

使用 numpy 库,只需 10 行简单的代码就能实现 k 近邻算法。

算法逻辑

对要分类的点(X)进行下列运算:

  1. 计算 X 与已知分类的所有点的距离(欧氏距离);
  2. 距离按照递增排序;
  3. 选取距离最小的 k 个点;
  4. 计算这 k 个点中,每个分类出现的频率;
  5. 取频率最高的分类为预测分类。

代码实现

def kNN(X, dataset, labels, k):
    # 计算X与所有其他点的差值
    diff = np.tile(X, (dataset.shape[0], 1)) - dataset
    # 计算欧氏距离
    distances = ((diff**2).sum(axis=1))**0.5
    # 排序
    sorted_idx = distances.argsort()
    class_cnt = {}
    for i in range(k):
        label = labels[sorted_idx[i]]
        class_cnt[label] =class_cnt.get(label, 0) + 1
    # 找出频率最高的分类作为预测结果
    sorted_class = sorted(class_cnt.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sorted_class[0][0]

enjoy!

相关文章

网友评论

      本文标题:10行代码实现kNN算法

      本文链接:https://www.haomeiwen.com/subject/jalsoftx.html