美文网首页
knn算法的原理及应用

knn算法的原理及应用

作者: 不分享的知识毫无意义 | 来源:发表于2020-01-12 19:42 被阅读0次

这个是为了练手写的代码,算法思路简单,是一个很好的练手项目。实现过程涉及的知识点包括数据生成,代码编写以及画图等。

1.原理

knn英文是k-NearestNeighbor,算法思路最简单了,这里边面试经常问k是什么意思,说的抽象一点就是未知样本距离已知样本群中最近的k个样本,通俗一点讲就是离你最近的k个朋友。然后怎么给未知样本分类呢,就是投票法,这k个样本中最多的类别就是这个未知样本的类别。最后一个问题,怎么计算距离,方法很多,我们一般用欧氏距离。

2.python代码

下载的是sklearn自带的鸢尾花数据集。
这里实现的关键是怎么找到离样本点最近的k个点,思路有千千万万种,
有用循环+字典的,
有用list+sort_index的,
我用的是array+argsort,可以说是数据一个格式处理,毫无违和感。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from scipy import stats
import numpy as np
import matplotlib.pyplot as plt

iris_dataset = load_iris()
iris_data = iris_dataset['data']#ndarray
iris_label = iris_dataset['target']
train_data, test_data, train_label,  test_label = train_test_split(iris_data, iris_label, test_size=0.3, random_state=7099)

def findk(datawl, datawn, label, k):
    dist = np.sqrt(np.sum((datawl - datawn) ** 2, axis=1))
    indexk = dist.argsort()
    labelo = stats.mode(label[indexk][0:k])[0][0]
    return labelo
predicty = []
for i in range(len(test_data)):
    tmppre = findk(train_data, test_data[i], train_label, 5)
    predicty.append(tmppre)
predicty = np.array(predicty)
accuracy = sum(predicty == test_label)/len(test_label)
print(accuracy)

def plotknn(datawl, datawn, label, k):
    plt.style.use('ggplot')
    plt.rcParams["font.family"] = 'Arial Unicode MS'
    fig = plt.figure(figsize=[4, 8])
    xx, yy = np.meshgrid(np.arange(datawl[:, 0].min() - 1, datawl[:, 0].max() + 1, 0.2), \
                         np.arange(datawl[:, 1].min() - 1, train_data[:, 1].max() + 1, 0.2))
    plt.scatter(datawl[:, 0], datawl[:, 1], c=label, s=80, cmap=plt.cm.spring, edgecolors='k')
    plt.xlim(datawl[:, 0].min() - 1, datawl[:, 0].max() + 1)
    plt.ylim(datawl[:, 1].min() - 1, train_data[:, 1].max() + 1)
    plt.scatter(datawn[0], datawn[1], marker='*', c='r', s=200)
    plt.text(datawn[0] + 0.1, datawn[1]+0.1, r'聚类结果是%s' % str(findk(datawl, datawn, label, k)))
    plt.show()

plotknn(train_data, test_data[0], train_label, 5)

最后看一下图片:


matplotlib作图

相关文章

网友评论

      本文标题:knn算法的原理及应用

      本文链接:https://www.haomeiwen.com/subject/kyeiactx.html