Li Hang's Statistical Learning Methods (3): The k-Nearest Neighbor Algorithm

Author: 文子轩 | Published 2018-02-09 08:49

The k-Nearest Neighbor Algorithm

Given a training data set and a new input instance, find the k instances in the training set that are nearest to the new instance, and decide its class from the classes of those k instances (usually by majority vote).

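As a baseline, here is a minimal brute-force sketch of that definition (not from the original post): compute the distance from the target to every training point, take the k closest, and vote. The labels below are made up purely for illustration.

    # Brute-force k-NN with majority voting (illustrative sketch, not the original code).
    from collections import Counter

    def knn_classify(train_points, train_labels, target, k=3):
        # Euclidean distance from the target to every training point
        dists = [sum((a - b) ** 2 for a, b in zip(p, target)) ** 0.5
                 for p in train_points]
        # indices of the k closest training instances
        nearest = sorted(range(len(train_points)), key=lambda i: dists[i])[:k]
        # majority vote over their labels
        return Counter(train_labels[i] for i in nearest).most_common(1)[0][0]

    points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
    labels = ['A', 'A', 'B', 'A', 'B', 'B']  # hypothetical labels
    print(knn_classify(points, labels, [9, 4], k=3))  # -> 'B'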

The k-Nearest Neighbor Model

The model has three elements: the distance metric, the choice of the value of k, and the classification decision rule.

Model

Once these three elements are fixed, the class of any instance (training or input) is determined; this amounts to partitioning the feature space into subregions.



Distance Metric

On the n-dimensional real vector space R^n, the Lp distance (also known as the Minkowski distance) is commonly used.

The Lp distance is defined as:

$$L_p(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^p \right)^{\frac{1}{p}}$$

When p = 2, it is the Euclidean distance:

$$L_2(x_i, x_j) = \left( \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|^2 \right)^{\frac{1}{2}}$$

When p = 1, it is the Manhattan distance:

$$L_1(x_i, x_j) = \sum_{l=1}^{n} \left| x_i^{(l)} - x_j^{(l)} \right|$$

When p = ∞, it is the maximum over the coordinate-wise distances:

$$L_\infty(x_i, x_j) = \max_{l} \left| x_i^{(l)} - x_j^{(l)} \right|$$

[Figure illustrating the Lp distances (image not available)]
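
As a quick numeric illustration (not in the original post), here is a small helper that computes the Lp distance for the three special cases above:

    def lp_distance(x, y, p=2):
        # coordinate-wise absolute differences
        diffs = [abs(a - b) for a, b in zip(x, y)]
        if p == float('inf'):  # L-infinity: the largest coordinate difference
            return max(diffs)
        return sum(d ** p for d in diffs) ** (1.0 / p)

    x1, x2 = [1, 1], [4, 5]
    print(lp_distance(x1, x2, p=2))             # Euclidean: 5.0
    print(lp_distance(x1, x2, p=1))             # Manhattan: 7.0
    print(lp_distance(x1, x2, p=float('inf')))  # Chebyshev (L-infinity): 4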

Choosing the Value of k

If k is too small, the prediction is easily influenced by noise and the model tends to overfit.

If k is too large, training instances far from the input also affect the prediction, which makes errors more likely.
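
In practice, k is usually chosen by cross-validation rather than by hand. A minimal sketch, assuming scikit-learn is installed (not part of the original post):

    from sklearn.datasets import load_iris
    from sklearn.model_selection import cross_val_score
    from sklearn.neighbors import KNeighborsClassifier

    X, y = load_iris(return_X_y=True)
    for k in (1, 3, 5, 7, 9, 15):
        acc = cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5).mean()
        print(k, round(acc, 3))  # pick the k with the best cross-validated accuracy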

Classification Decision Rule

Measured with the 0-1 loss function, the misclassification rate is:

$$\frac{1}{k} \sum_{x_i \in N_k(x)} I(y_i \neq c_j) = 1 - \frac{1}{k} \sum_{x_i \in N_k(x)} I(y_i = c_j)$$

where N_k(x) is the set of the k nearest neighbors. To minimize the left-hand side, the right-hand term

$$\sum_{x_i \in N_k(x)} I(y_i = c_j)$$

must be maximized, so majority voting is equivalent to empirical risk minimization.
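
A tiny numeric check of this argument (illustrative only; the labels are made up): among the labels of the k neighbors, the majority class is exactly the one that minimizes the empirical misclassification rate.

    neighbor_labels = ['A', 'B', 'B', 'B', 'A']  # labels y_i of the k = 5 nearest neighbors
    k = len(neighbor_labels)
    for c in set(neighbor_labels):
        error_rate = sum(y != c for y in neighbor_labels) / k
        print(c, error_rate)  # 'A' -> 0.6, 'B' -> 0.4, so predict 'B'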

Implementing k-NN: the kd-Tree

The core difficulty is finding the k nearest neighbors quickly. The naive approach, a linear scan over all training points, is too slow to be practical; the method introduced here is the kd-tree.

Constructing the kd-Tree

Initialize a subset S of the data set T as S = T, the current node as node = root, and the dimension index as i = 0, then apply the following recursively to S:

Find the point whose value in dimension i is the median of S along that dimension, and pass a hyperplane through this point, perpendicular to the i-th coordinate axis. Add the point as a child of node. The hyperplane splits the space into two parts; repeat the procedure on each part (S = S', ++i, node = current) until no further split is possible.


    T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]

    class node:
        def __init__(self, point):
            self.left = None
            self.right = None
            self.point = point

    def median(lst):
        m = len(lst) // 2  # index of the median element after sorting
        return lst[m], m

    def build_kdtree(data, d):
        # sort on the current splitting dimension and split at the median point
        data = sorted(data, key=lambda x: x[d])
        p, m = median(data)
        tree = node(p)

        del data[m]
        print(data, p)

        if m > 0: tree.left = build_kdtree(data[:m], not d)
        if len(data) > 1: tree.right = build_kdtree(data[m:], not d)
        return tree

    kd_tree = build_kdtree(T, 0)
    print(kd_tree)
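
Running this on the sample data T should reproduce the kd-tree from the book's worked example: root [7, 2], a left subtree rooted at [5, 4] (children [2, 3] and [4, 7]), and a right subtree rooted at [9, 6] (child [8, 1]).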
    

Visualization

Visualizing the construction takes a bit more effort: the intermediate results have to be saved and then displayed appropriately.

    # -*- coding:utf-8 -*-
    # Filename: kdtree.py
    # Author:hankcs
    # Date: 2015/2/4 15:01
    import copy
    import itertools
    from matplotlib import pyplot as plt
    from matplotlib.patches import Rectangle
    from matplotlib import animation

    T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]


    def draw_point(data):
        # scatter-plot the raw data points
        X, Y = [], []
        for p in data:
            X.append(p[0])
            Y.append(p[1])
        plt.plot(X, Y, 'bo')


    def draw_line(xy_list):
        for xy in xy_list:
            x, y = xy
            plt.plot(x, y, 'g', lw=2)


    def draw_square(square_list):
        currentAxis = plt.gca()
        colors = itertools.cycle(["r", "b", "g", "c", "m", "y", '#EB70AA', '#0099FF'])
        for square in square_list:
            currentAxis.add_patch(
                Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
                          color=next(colors)))


    def median(lst):
        m = len(lst) // 2  # index of the median element after sorting
        return lst[m], m


    history_quare = []  # every (sub)rectangle recorded while the tree is built


    def build_kdtree(data, d, square):
        history_quare.append(square)
        data = sorted(data, key=lambda x: x[d])
        p, m = median(data)

        del data[m]
        print(data, p)

        if m >= 0:
            sub_square = copy.deepcopy(square)
            if d == 0:
                sub_square[1][0] = p[0]
            else:
                sub_square[1][1] = p[1]
            history_quare.append(sub_square)
            if m > 0: build_kdtree(data[:m], not d, sub_square)
        if len(data) > 1:
            sub_square = copy.deepcopy(square)
            if d == 0:
                sub_square[0][0] = p[0]
            else:
                sub_square[0][1] = p[1]
            build_kdtree(data[m:], not d, sub_square)


    build_kdtree(T, 0, [[0, 0], [10, 10]])
    print(history_quare)


    # draw an animation showing how the tree is built, using the recorded history
    # first set up the figure, the axes, and the plot element we want to animate
    fig = plt.figure()
    ax = plt.axes(xlim=(0, 10), ylim=(0, 10))
    line, = ax.plot([], [], 'g', lw=2)

    # initialization function: plot the background of each frame
    def init():
        plt.axis([0, 10, 0, 10])
        plt.grid(True)
        plt.xlabel('x_1')
        plt.ylabel('x_2')
        plt.title('build kd tree (www.hankcs.com)')
        draw_point(T)


    currentAxis = plt.gca()
    colors = itertools.cycle(["#FF6633", "g", "#3366FF", "c", "m", "y", '#EB70AA', '#0099FF', '#66FFFF'])

    # animation function: called sequentially, one frame per recorded rectangle
    def animate(i):
        square = history_quare[i]
        currentAxis.add_patch(
            Rectangle((square[0][0], square[0][1]), square[1][0] - square[0][0], square[1][1] - square[0][1],
                      color=next(colors)))
        return

    # call the animator; blit=True would only re-draw the parts that have changed
    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history_quare), interval=1000, repeat=False,
                                   blit=False)
    plt.show()
    anim.save('kdtree_build.gif', fps=2, writer='imagemagick')  # requires ImageMagick
    

Searching the kd-Tree

The code above does not actually search the kd-tree, so let us implement the search now.

The search is a recursive procedure, just like searching a binary search tree. First find the position where the target point would be inserted, then walk back up the tree, using the distance from the current best point to the target as the radius of a hypersphere, and update the nearest neighbor (or the k nearest neighbors) with the points that fall inside it. Taking the single nearest neighbor as an example, an implementation follows (because the test data are simple, this implementation omits the check for intersection between the hypersphere and the hyperrectangles):

    # -*- coding:utf-8 -*-
    # Filename: search_kdtree.py
    # Author:hankcs
    # Date: 2015/2/4 15:01

    T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]


    class node:
        def __init__(self, point):
            self.left = None
            self.right = None
            self.point = point
            self.parent = None

        def set_left(self, left):
            if left is None:
                return
            left.parent = self
            self.left = left

        def set_right(self, right):
            if right is None:
                return
            right.parent = self
            self.right = right


    def median(lst):
        m = len(lst) // 2  # index of the median element after sorting
        return lst[m], m


    def build_kdtree(data, d):
        # split on dimension d at the median point, then recurse on each half
        data = sorted(data, key=lambda x: x[d])
        p, m = median(data)
        tree = node(p)

        del data[m]

        if m > 0: tree.set_left(build_kdtree(data[:m], not d))
        if len(data) > 1: tree.set_right(build_kdtree(data[m:], not d))
        return tree


    def distance(a, b):
        print(a, b)
        return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5


    def search_kdtree(tree, d, target):
        # descend to the leaf region that would contain the target
        if target[d] < tree.point[d]:
            if tree.left is not None:
                return search_kdtree(tree.left, not d, target)
        else:
            if tree.right is not None:
                return search_kdtree(tree.right, not d, target)

        def update_best(t, best):
            if t is None:
                return
            t = t.point
            d = distance(t, target)
            if d < best[1]:
                best[1] = d
                best[0] = t

        # walk back up, checking the sibling subtrees along the path
        best = [tree.point, float('inf')]
        while tree.parent is not None:
            update_best(tree.parent.left, best)
            update_best(tree.parent.right, best)
            tree = tree.parent
        return best[0]


    kd_tree = build_kdtree(T, 0)
    print(search_kdtree(kd_tree, 0, [9, 4]))
    

Output

    [8, 1] [9, 4]
    [5, 4] [9, 4]  
    [9, 6] [9, 4]
    [9, 6]
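
A brute-force check (not in the original post) confirms that [9, 6] is indeed the training point closest to the target [9, 4]:

    T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
    target = [9, 4]
    # squared Euclidean distance is enough for comparison
    nearest = min(T, key=lambda p: (p[0] - target[0]) ** 2 + (p[1] - target[1]) ** 2)
    print(nearest)  # [9, 6]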
    
