K近邻算法
算法思想
KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:
- 计算测试数据与各个训练数据之间的距离
- 这里距离一般使用欧氏距离或曼哈顿距离:
- 按照距离的递增关系进行排序;
- 选取距离最小的K个点(一般K不超过20);
- 确定前K个点所在类别的出现频率;
- 返回前K个点中出现频率最高的类别作为测试数据的预测分类。
案例
1. 鸢尾花数据集
鸢尾花数据集共收集了三类鸢尾花,即Setosa鸢尾花、Versicolour鸢尾花和Virginica鸢尾花,每一类鸢尾花收集了50条样本记录,共计150条。
数据集包括4个属性,分别为花萼的长、花萼的宽、花瓣的长和花瓣的宽。对花瓣我们可能比较熟悉,花萼是什么呢?花萼是花冠外面的绿色被叶,在花尚未开放时,保护着花蕾。四个属性的单位都是cm
,属于数值变量,四个属性均不存在缺失值的情况,以下是各属性的一些统计值如下:
属性 | 最大值 | 最小值 | 均值 | 方差 |
---|---|---|---|---|
萼长 | 7.9 | 4.3 | 5.84 | 0.83 |
萼宽 | 4.4 | 2.0 | 3.05 | 0.43 |
瓣长 | 6.9 | 1.0 | 3.76 | 1.76 |
瓣宽 | 2.5 | 0.1 | 1.20 | 0.76 |
参考:鸢尾花数据集
# 导入数据库
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 1.导入数据和标签(/目标值)
dataset = datasets.load_iris()
data_X = dataset.data
data_Y = dataset.target
# 2.将数据分成训练和测试集
X_train, X_test, y_train, y_test = train_test_split(
data_X, data_Y, test_size=0.3)
# 3.布置模型
knn = KNeighborsClassifier()
knn.fit(X_train, y_train)
print(knn.predict(X_test))
print(y_test)
参考:莫烦Python
网友评论