美文网首页
纯numpy动手写K近邻模型

纯numpy动手写K近邻模型

作者: _龙雀 | 来源:发表于2019-05-29 20:06 被阅读0次

简单清晰的的基于线性扫描实现的KNN模型

import numpy as np

class KNN():
    def __init__(self):
        self.k = 1
        self.distances = []
        
    def distance(self, vector1, vector2):
        return np.sqrt(np.sum(np.square(vector1-vector2)))
    
    def predict(self, X, y, x):
        for i in range(X.shape[0]):
            self.distances.append(self.distance(x, X[i]))
        topK = np.array(self.distances).argsort()[0:self.k] #获取前k个最小值的索引
        #统计正负样本数目 正负样本标签分别为1 0
        pos_label = 0
        neg_label = 0
        for i in topK: #遍历索引所在位置的标签
            if y[i] == 1:
                pos_label += 1
            else:
                neg_label += 1
        if pos_label >= neg_label:
            return 1
        else:
            return 0
        
if __name__ == '__main__':
    X = np.array([[5,4],[9,6],[4,7],[2,3],[8,1],[7,2]])
    y = np.array([1,1,1,0,0,0])
    knn = KNN()
    s = np.array([[5,3]])
    s_pred = knn.predict(X, y, s)
    print("My KNN model predict the label is d%", s_pred)
    
    #调用sklearn
    from sklearn.neighbors import KNeighborsClassifier
    model = KNeighborsClassifier()
    model.fit(X,y)
    s_pred = model.predict(s)
    print("sklearn KNN model predict the label is d%", s_pred)

相关文章

网友评论

      本文标题:纯numpy动手写K近邻模型

      本文链接:https://www.haomeiwen.com/subject/ihyutctx.html