KNN,即K nearest neighbor,K近邻算法。KNN的思想非常简单,所需的数学知识较少。比如下图,星星是一个新的样本,要判断星星是属于蓝色的还是黄色的样本分类,就要看它周围的邻居是什么分类。假设K=3,就是看周围三个点的分类,如图,周围有两个红点,一个黄点,应该归类为红色类别。
导入数据:
from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
data = cancer.data
target = cancer.target
import numpy as np
X = np.array([20, 30])
plt.scatter(data[target==1,0], data[target==1,1], alpha = 0.5, color = 'green', label = '1')
plt.scatter(data[target==0,0], data[target==0,1], alpha = 0.5, color = 'orange', label = '0')
plt.plot(X[0], X[1], color = 'purple', marker = '*', ms = 20)
plt.legend(loc = 'upper right')
plt.show()
KNN的计算中,先计算距离,比较常见的是欧拉距离:
也就是两个点(或者多个点)对应的横纵坐标差的平方和,然后开平方。
根据欧拉距离写一个KNN的实现:
def KNN_test(X_train, y_train, test, K):
distance = []
for t in X_train:
d = sqrt(np.sum((t - test)**2))
distance.append(d)
ind = np.argsort(distance)
topK_target = y_train[ind[:K]]
c = Counter(topK_target)
return c.most_common(1)[0][0]
判断刚才的点是属于哪一类:
KNN_test(data[:, :2], target, X, 6)
sklearn中的实现:
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=6)
knn.fit(data[:, :2], target)
knn.predict(X.reshape(1, -1))
网友评论