美文网首页
2019-03-05深度学习——搭建一个简单的KnnClassi

2019-03-05深度学习——搭建一个简单的KnnClassi

作者: Hie_9e55 | 来源:发表于2019-03-05 21:19 被阅读0次

    KNN思想

    NN选出与目标图片距离(范数)最近的一张图片
    KNN选出与目标图片距离(范数)最近的K张图片,并统计K张图中出现次数最多的类型,即为预测类型

    代码实现

    分为两个部分
    第一部分是knn模型
    第二部分是模型的使用

    1. KNN.py
    # KNN.py
    # 导入所需要的库
    # 这里我们需要使用numpy库进行矩阵运算
    # 使用collections中的Counter
    import numpy as np
    from collections import Counter
    
    class KNearestNeighbor:
    
        def __init__(self, k = 7):
            self.k = k
    
        # 训练模型,KNN只是简单的导入即可,因为K是一个超参数,X是数据,n*3072,Y是数据标签,n*1
        def train(self, X, y):
            self.Xtr = X
            self.ytr = y
    
        # 使用模型进行预测,X是test集的数据
        def predict(self, X):
            num_test = X.shape[0]# test数据个数
            Ypred = np.zeros((num_test, len(self.k)))# 初始化预测结果
            
            for i in range(num_test):# 每次迭代一张图片
    
                distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)# 计算范数一
                temp = np.argsort(distances)# numpy.argsort()函数返回的是数组值从小到大的索引值
                
                index = []
                for j in range(len(self.k)):
                    # 取出前k个下标
                    id = temp[0:self.k[j]]
                    # 取出k个下标所对应的类型
                    temp_y = np.array(self.ytr)[id]
                
                    # 取出k个下标中出现次数最多的label
                    index.append(Counter(temp_y).most_common(1)[0][0])
    
                print(np.array(index), i)
    
                Ypred[i] = np.array(index)# 记录第i张图片在进行knn之后的label
    
            return Ypred
    
    1. runKNN.py
    # runKNN.py
    # 导入所需要的库
    # pickle用于解压数据
    # matplotlib用于绘图
    # numpy用于矩阵运算
    # KNN用于预测label
    import pickle
    from matplotlib import pyplot as plt
    import numpy as np
    from KNN import KNearestNeighbor
    
    # 数据地址
    filename1 = 'D:/Download/cifar-10-batches-py/data_batch_1'
    filename2 = 'D:/Download/cifar-10-batches-py/data_batch_2'
    filename3 = 'D:/Download/cifar-10-batches-py/data_batch_3'
    filename4 = 'D:/Download/cifar-10-batches-py/data_batch_4'
    filename5 = 'D:/Download/cifar-10-batches-py/data_batch_5'
    filename_test = 'D:/Download/cifar-10-batches-py/test_batch'
    
    # 定义导入数据的函数
    def load_file(filename):
        with open(filename, 'rb') as fo:
            data = pickle.load(fo, encoding='latin1')
        return data
    
    # 使用pickle暴力导入数据  
    data = []
    data.append(load_file(filename1))
    data.append(load_file(filename2))
    data.append(load_file(filename3))
    data.append(load_file(filename4))
    data.append(load_file(filename5))
    test_batch = load_file(filename_test)
    
    # 作业要求的K值
    k = [1,3,5,7,9]
    
    # 初始化几个会用到的list
    result = []
    validation = []
    
    # 建立模型
    net = KNearestNeighbor(k)
    
    # 5个batch分别迭代
    for i in range(5):
        # 训练
        net.train(data[i]['data'], data[i]['labels'])
        # 预测
        result.append(net.predict(test_batch['data']))
        # 计算预测结果与实际label之间的误差
        temp = result[i] - np.array(test_batch['labels']).reshape((np.array(test_batch['labels']).shape[0],1))
        temp[temp != 0] = 1
        # 计算准确度
        validation.append(1 - (sum(abs(temp)) / test_batch['data'].shape[0]))
    
        print('batch', i, '', validation[i])
    
    print(validation)
    
    # 绘图
    plt.title("Cross Validation") 
    plt.xlabel("k") 
    plt.ylabel("validation") 
    plt.axis([0, 10, 0.2, 0.26])
    validation = np.matrix(validation).T
    ave = np.sum(validation, axis = 1) / len(k)
    plt.plot(k, ave)
    fig = plt.plot(k, validation, 'ro')
    plt.setp(fig, color='b')
    plt.savefig('fig')
    plt.show()
    
    1. KNN结果
      可以看到当k=7的时候,准确度最高(其实也高不到哪儿去)


      fig.png

    相关文章

      网友评论

          本文标题:2019-03-05深度学习——搭建一个简单的KnnClassi

          本文链接:https://www.haomeiwen.com/subject/nndauqtx.html