美文网首页技术类文章收集大杂烩Python资源收集
python机器学习——KNN算法简单入门(真的很简单!)

python机器学习——KNN算法简单入门(真的很简单!)

作者: AlanLau | 来源:发表于2017-04-30 00:06 被阅读2653次

    所有代码请移步GitHub——kNNbyPython

    很多人在第一次听到机器学习的时候都不知所措,无从下手。起初我也是这样的,各种看别人的博客,吴恩达的课程也死磕,但效果不佳。后来发现一个神奇的网站k-近邻算法实现手写数字识别系统--《机器学习实战 》,跟着过了一遍之后感觉还不错,也顺便买了《机器学习实战》这本书,接着就正式入坑机器学习。
    KNN算法应该是机器学习中最简单的算法之一,作为机器学习的入门是个非常不错的选择。

    KNN算法思路

    KNN(K-Nearest Neighbor)算法的理论基础网上一查一大把,我这里就不赘述,这里我讲自己的理解。

    KNN算法属于机器学习中的监督算法,主要用于分类。

    首先,在二维坐标轴中,有四个点,分别是a1(1,1),a2(1,2),b1(3,3),b2(3,4)。其中,a1,a2为A类,b1,b2为B类
    这里用matplotlib实现一下这四个点,更加直观点。

    实现这张图的代码,感兴趣的可以看一下。

    # -*- coding: utf-8 -*-
    # @Date     : 2017-04-28 16:52:44
    # @Author   : Alan Lau (rlalan@outlook.com)
    # @Language : Python3.5
    
    from matplotlib import pyplot as plt
    import numpy as np
    
    # 定义四个点的坐标
    a1 = np.array([1, 1])
    a2 = np.array([1, 2])
    b1 = np.array([3, 3])
    b2 = np.array([3, 4])
    
    # 四个点坐标分别赋值给X,Y
    X1, Y1 = a1
    X2, Y2 = a2
    X3, Y3 = b1
    X4, Y4 = b2
    
    plt.title('show data')
    plt.scatter(X1, Y1, color="blue", label="a1")
    plt.scatter(X2, Y2, color="blue", label="a2")
    plt.scatter(X3, Y3, color="red", label="b1")
    plt.scatter(X4, Y4, color="red", label="b2")
    plt.legend(loc='upper left')
    
    plt.annotate(r'a1(1,1)', xy=(X1, Y1), xycoords='data', xytext=(+10, +30), textcoords='offset points', fontsize=16, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    plt.annotate(r'a2(1,2)', xy=(X2, Y2), xycoords='data', xytext=(+10, +30), textcoords='offset points', fontsize=16, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    plt.annotate(r'b1(3,3)', xy=(X3, Y3), xycoords='data', xytext=(+10, +30), textcoords='offset points', fontsize=16, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    plt.annotate(r'b2(3,4)', xy=(X4, Y4), xycoords='data', xytext=(+10, +30), textcoords='offset points', fontsize=16, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    plt.show()
    

    然后,问题出现了,现在突然冒出个c(2,1)


    我现在想知道的是,c(2,1)这个点,在AB两个类中是属于A类,还是数据B类。

    怎么做?
    1.计算c和其余所有点的距离。
    2.将计算出的距离集合进行升序排序(即距离最短的排列在前面)。
    3.获得距离集合降序排序的前k个距离。
    4.统计出在前k个距离中,出现频次最多的类别。
    然后我们把已经知道分类的四个点a1,a2,b1,b3称为训练数据,把未知类别的c称为测试数据。

    这里的k取值一般为小于等于20的常数,具体的取值,看不同的样本。同样,如何确定k的值,获得最佳的计算结果,也是kNN算法的一个难点。

    现在跟着上面的例子走一遍,这里k取3(训练数据才4个,最大只能取3)。

    1.计算c和其余所有点的距离

    计算距离的方法我这里使用欧式距离,具体python代码可以参考我的另一篇博文 python实现各种距离,同样,在众多计算距离的方法中,确定使用kNN算法时用哪个距离算法也是该算法的难点之一。

    此图代码:

    # 如想运行,请拼接上一段代码
    import math
    
    def Euclidean(vec1, vec2):
        npvec1, npvec2 = np.array(vec1), np.array(vec2)
        return math.sqrt(((npvec1-npvec2)**2).sum())
        
    # 显示距离
    def show_distance(exit_point, c):
        line_point = np.array([exit_point, c])
        x = (line_point.T)[0]
        y = (line_point.T)[1]
        o_dis = round(Euclidean(exit_point, c), 2)  # 计算距离
        mi_x, mi_y = (exit_point+c)/2  # 计算中点位置,来显示“distance=xx”这个标签
        plt.annotate('distance=%s' % str(o_dis), xy=(mi_x, mi_y), xycoords='data', xytext=(+10, 0), textcoords='offset points', fontsize=10, arrowprops=dict(arrowstyle="-", connectionstyle="arc3,rad=.2"))
        return plt.plot(x, y, linestyle="--", color='black', lw=1)
    
    show_distance(a1, c)
    show_distance(a2, c)
    show_distance(b1, c)
    show_distance(b2, c)
    plt.show()
    

    代码的注释中怎么引用自己写的包和.py,看一参考我的博客python中import自己写的.py

    欧式距离计算方法

    def Euclidean(vec1, vec2):
        npvec1, npvec2 = np.array(vec1), np.array(vec2)
        return math.sqrt(((npvec1-npvec2)**2).sum())
    
    2.将计算出的距离集合进行升序排序(即距离最短的排列在前面)

    |升序序号|点标签|标签所属类别|点坐标|与c点距离|
    | ------------- |:-------------: |:-------------:| -----:|
    | 1 | a1 |A | (1,1) |1.0|
    | 2 | a2 |A | (1,2) |1.41|
    | 3 | b1 |B | (3,3) |2.24|
    | 4 | b2 |B | (3,4) |3.16|

    3.获得距离集合升序排序的前k个距离

    k取值为3,因此保留升序排序前三的距离

    |升序序号|点标签|标签所属类别|点坐标|与c点距离|
    | ------------- |:-------------: |:-------------:| -----:|
    | 1 | a1 |A | (1,1) |1.0|
    | 2 | a2 |A | (1,2) |1.41|
    | 3 | b1 |B | (3,3) |2.24|

    4.统计出在前k个距离中,出现频次最多的类别

    肉眼直接看出,频次最多的类别是A。因此,c点属于A类。

    5.总结

    在上面这个例子中我用了四个点,即四个向量,同时为了方便理解,我使用的是二维坐标平面。但是在真正的kNN实战中,则涉及的训练数量是非常庞大的,同样,也不会单单局限于二维,而是多维向量。但是,其实现方法都是相同的。当然,我上面举的例子是不能用来实际使用的,因为训练数据太少。
    上述例子的所有代码,感兴趣可以自己过一遍:

    # -*- coding: utf-8 -*-
    # @Date     : 2017-04-28 16:52:44
    # @Author   : Alan Lau (rlalan@outlook.com)
    # @Language : Python3.5
    
    from matplotlib import pyplot as plt
    import numpy as np
    import math
    
    
    # 定义四个点的坐标
    a1 = np.array([1, 1])
    a2 = np.array([1, 2])
    b1 = np.array([3, 3])
    b2 = np.array([3, 4])
    c = np.array([2, 1])
    
    # 四个点坐标分别赋值给X,Y
    X1, Y1 = a1
    X2, Y2 = a2
    X3, Y3 = b1
    X4, Y4 = b2
    X5, Y5 = c
    
    plt.title('show data')
    plt.scatter(X1, Y1, color="blue", label="a1")
    plt.scatter(X2, Y2, color="blue", label="a2")
    plt.scatter(X3, Y3, color="red", label="b1")
    plt.scatter(X4, Y4, color="red", label="b2")
    plt.scatter(X5, Y5, color="yellow", label="c")
    
    plt.annotate(r'a1(1,1)', xy=(X1, Y1), xycoords='data', xytext=(+10, +20), textcoords='offset points',fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    plt.annotate(r'a2(1,2)', xy=(X2, Y2), xycoords='data', xytext=(+10, +20), textcoords='offset points',fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    plt.annotate(r'b1(3,3)', xy=(X3, Y3), xycoords='data', xytext=(+10, +20), textcoords='offset points',fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    plt.annotate(r'b2(3,4)', xy=(X4, Y4), xycoords='data', xytext=(+10, +20), textcoords='offset points', fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    
    plt.annotate(r'c(2,1)', xy=(X5, Y5), xycoords='data', xytext=(+30, 0), textcoords='offset points', fontsize=12, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))
    
    
    def Euclidean(vec1, vec2):
        npvec1, npvec2 = np.array(vec1), np.array(vec2)
        return math.sqrt(((npvec1-npvec2)**2).sum())
    
    # 显示距离
    def show_distance(exit_point, c):
        line_point = np.array([exit_point, c])
        x = (line_point.T)[0]
        y = (line_point.T)[1]
        o_dis = round(Euclidean(exit_point, c), 2)  # 计算距离
        mi_x, mi_y = (exit_point+c)/2  # 计算中点位置,来显示“distance=xx”这个标签
        plt.annotate('distance=%s' % str(o_dis), xy=(mi_x, mi_y), xycoords='data', xytext=(+10, 0), textcoords='offset points', fontsize=10, arrowprops=dict(arrowstyle="-", connectionstyle="arc3,rad=.2"))
        return plt.plot(x, y, linestyle="--", color='black', lw=1)
    
    show_distance(a1, c)
    show_distance(a2, c)
    show_distance(b1, c)
    show_distance(b2, c)
    plt.show()
    

    实战

    实战这里使用k-近邻算法实现手写数字识别系统--《机器学习实战 》中的数据进行,但是本人的代码与网站提供的代码有差异。

    准备数据

    在使用数据之前,我先对网站提供的数据进行预处理,方便使用numpy读取。
    网站提供数据:


    处理后的数据:


    实际上就是在数字之间加上空格,方便numpy识别并分割数据。
    数据预处理的代码:

    # -*- coding: utf-8 -*-
    # @Date    : 2017-04-03 16:04:19
    # @Author  : Alan Lau (rlalan@outlook.com)
    
    def fwalker(path):
        fileArray = []
        for root, dirs, files in os.walk(path):
            for fn in files:
            eachpath = str(root+'\\'+fn)
            fileArray.append(eachpath)
        return fileArray
    
    def writetxt(path, content, code):
        with open(path, 'a', encoding=code)as f:
            f.write(content)
        return path+' is ok!'
    
    def readtxt(path, encoding):
        with open(path, 'r', encoding=encoding) as f:
            lines = f.readlines()
        return lines
    
    def buildfile(echkeyfile):
        if os.path.exists(echkeyfile):
            #创建前先判断是否存在文件夹,if存在则删除
            shutil.rmtree(echkeyfile)
            os.makedirs(echkeyfile)
        else:
            os.makedirs(echkeyfile)#else则创建语句
    return echkeyfile
    
    def change_data(files, inputpath):
        trainpath = buildfile(inputpath+'\\'+'trainingDigits')
        testpath = buildfile(inputpath+'\\'+'testDigits')
        for file in files:
            ech_name = (file.split('\\'))[-2:]
            new_path = inputpath+'\\'+'\\'.join(ech_name)
            ech_content = readtxt(file, 'utf8')
            new_content = []
            for ech_line in ech_content:
                line_ary = list(ech_line.replace('\n', '').replace('\r', ''))
                new_content.append(' '.join(line_ary))
            print(writetxt(new_path, '\n'.join(new_content), 'utf8'))
    
    
    def main():
        datapath =r'..\lab3_0930\digits'
        inputpath = buildfile(r'..\lab3_0930\input_digits')
        files = fwalker(datapath)
        change_data(files, inputpath)
    
    if __name__ == '__main__':
        main()
    

    实现代码

    教程网站中利用list下标索引将标签和向量进行对应,而我使用将每一个标签和向量放到分别一个list中,再将这些list放到一个list内,类似于实现字典。如[[label1,vector1],[label2,vector2],[label3,vector3],...]

    # -*- coding: utf-8 -*-
    # @Date    : 2017-04-03 15:47:04
    # @Author  : Alan Lau (rlalan@outlook.com)
    
    import os
    import math
    import collections
    import numpy as np
    
    def Euclidean(vec1, vec2):
        npvec1, npvec2 = np.array(vec1), np.array(vec2)
        return math.sqrt(((npvec1-npvec2)**2).sum())
        
    def fwalker(path):
        fileArray = []
        for root, dirs, files in os.walk(path):
            for fn in files:
            eachpath = str(root+'\\'+fn)
            fileArray.append(eachpath)
        return fileArray
    
    def orderdic(dic, reverse):
        ordered_list = sorted(
            dic.items(), key=lambda item: item[1], reverse=reverse)
        return ordered_list
    
    def get_data(data_path):
        label_vec = []
        files = fwalker(data_path)
        for file in files:
            ech_label_vec = []
            ech_label = int((file.split('\\'))[-1][0])# 获取每个向量的标签
            ech_vec = ((np.loadtxt(file)).ravel())# 获取每个文件的向量
            ech_label_vec.append(ech_label) # 将一个文件夹的标签和向量放到同一个list内
            ech_label_vec.append(ech_vec) # 将一个文件夹的标签和向量放到同一个list内,目的是将标签和向量对应起来,类似于字典,这里不直接用字典因为字典的键(key)不可重复。
            label_vec.append(ech_label_vec) # 再将所有的标签和向量存入一个list内,构成二维数组
        return label_vec
    
    
    def find_label(train_vec_list, vec, k):
        get_label_list = []
        for ech_trainlabel_vec in train_vec_list:
            ech_label_distance = []
            train_label, train_vec = ech_trainlabel_vec[0], ech_trainlabel_vec[1]
            vec_distance = Euclidean(train_vec, vec)# 计算距离
            ech_label_distance.append(train_label)
            ech_label_distance.append(vec_distance)# 将距离和标签对应存入list
            get_label_list.append(ech_label_distance)
        result_k = np.array(get_label_list)
        order_distance = (result_k.T)[1].argsort()# 对距离进行排序
        order = np.array((result_k[order_distance].T)[0])
        top_k = np.array(order[:k], dtype=int) # 获取前k距离和标签
        find_label = orderdic(collections.Counter(top_k), True)[0][0]# 统计在前k排名中标签出现频次
        return find_label
    
    
    def classify(train_vec_list, test_vec_list, k):
        error_counter = 0 #计数器,计算错误率
        for ech_label_vec in test_vec_list:
            label, vec = ech_label_vec[0], ech_label_vec[1]
            get_label = find_label(train_vec_list, vec, k) # 获得学习得到的标签
            print('Original label is:'+str(label) +
                  ', kNN label is:'+str(get_label))
            if str(label) != str(get_label):
                error_counter += 1
            else:
                continue
        true_probability = str(round((1-error_counter/len(test_vec_list))*100, 2))+'%'
        print('Correct probability:'+true_probability)
    
    
    def main():
        k = 3
        train_data_path =r'..\lab3_0930\input_digits\trainingDigits'
        test_data_path =r'..\lab3_0930\input_digits\testDigits'
        train_vec_list = get_data(train_data_path)
        test_vec_list = get_data(test_data_path)
        classify(train_vec_list, test_vec_list, k)
    
    if __name__ == '__main__':
        main()
    
    

    正确率

    相关文章

      网友评论

        本文标题:python机器学习——KNN算法简单入门(真的很简单!)

        本文链接:https://www.haomeiwen.com/subject/ofjozttx.html