美文网首页
手写KNN实践

手写KNN实践

作者: A_You | 来源:发表于2019-12-28 18:40 被阅读0次

手写代码是算法最好的实现机制:

from sklearn import datasets
from collections import Counter  # 为了做投票
from sklearn.model_selection import train_test_split
import numpy as np

# 导入iris数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=2003)

def euc_dis(instance1, instance2):
    """
    计算两个样本instance1和instance2之间的欧式距离
    instance1: 第一个样本, array型
    instance2: 第二个样本, array型
    """
    # 注意运算符的优先级
    dist = np.sqrt(sum((instance1 - instance2) ** 2))
    return dist


def knn_classify(X, y, testInstance, k):
    """
    给定一个测试数据testInstance, 通过KNN算法来预测它的标签。
    X: 训练数据的特征
    y: 训练数据的标签
    testInstance: 测试数据,这里假定一个测试数据 array型
    k: 选择多少个neighbors?
    """
    # TODO  返回testInstance的预测标签 = {0,1,2}
    y_distances=[euc_dis(x,testInstance) for x in X]
    kn = np.argsort(y_distances)[:k]
    count = Counter(y[kn])
    return count.most_common()[0][0]

# 预测结果。
predictions = [knn_classify(X_train, y_train, data, 3) for data in X_test]
correct = np.count_nonzero((predictions==y_test)==True)
print ("Accuracy is: %.3f" %(correct/len(X_test)))
不同K下的校验结果:
k = 3; Accuracy is:0.921
k = 5; Accuracy is:0.895
k = 7; Accuracy is:0.921

相关文章

网友评论

      本文标题:手写KNN实践

      本文链接:https://www.haomeiwen.com/subject/gynsoctx.html