这个是为了练手写的代码,算法思路简单,是一个很好的练手项目。实现过程涉及的知识点包括数据生成,代码编写以及画图等。
1.原理
knn英文是k-NearestNeighbor,算法思路最简单了,这里边面试经常问k是什么意思,说的抽象一点就是未知样本距离已知样本群中最近的k个样本,通俗一点讲就是离你最近的k个朋友。然后怎么给未知样本分类呢,就是投票法,这k个样本中最多的类别就是这个未知样本的类别。最后一个问题,怎么计算距离,方法很多,我们一般用欧氏距离。
2.python代码
下载的是sklearn自带的鸢尾花数据集。
这里实现的关键是怎么找到离样本点最近的k个点,思路有千千万万种,
有用循环+字典的,
有用list+sort_index的,
我用的是array+argsort,可以说是数据一个格式处理,毫无违和感。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from scipy import stats
import numpy as np
import matplotlib.pyplot as plt
iris_dataset = load_iris()
iris_data = iris_dataset['data']#ndarray
iris_label = iris_dataset['target']
train_data, test_data, train_label, test_label = train_test_split(iris_data, iris_label, test_size=0.3, random_state=7099)
def findk(datawl, datawn, label, k):
dist = np.sqrt(np.sum((datawl - datawn) ** 2, axis=1))
indexk = dist.argsort()
labelo = stats.mode(label[indexk][0:k])[0][0]
return labelo
predicty = []
for i in range(len(test_data)):
tmppre = findk(train_data, test_data[i], train_label, 5)
predicty.append(tmppre)
predicty = np.array(predicty)
accuracy = sum(predicty == test_label)/len(test_label)
print(accuracy)
def plotknn(datawl, datawn, label, k):
plt.style.use('ggplot')
plt.rcParams["font.family"] = 'Arial Unicode MS'
fig = plt.figure(figsize=[4, 8])
xx, yy = np.meshgrid(np.arange(datawl[:, 0].min() - 1, datawl[:, 0].max() + 1, 0.2), \
np.arange(datawl[:, 1].min() - 1, train_data[:, 1].max() + 1, 0.2))
plt.scatter(datawl[:, 0], datawl[:, 1], c=label, s=80, cmap=plt.cm.spring, edgecolors='k')
plt.xlim(datawl[:, 0].min() - 1, datawl[:, 0].max() + 1)
plt.ylim(datawl[:, 1].min() - 1, train_data[:, 1].max() + 1)
plt.scatter(datawn[0], datawn[1], marker='*', c='r', s=200)
plt.text(datawn[0] + 0.1, datawn[1]+0.1, r'聚类结果是%s' % str(findk(datawl, datawn, label, k)))
plt.show()
plotknn(train_data, test_data[0], train_label, 5)
最后看一下图片:
matplotlib作图
网友评论