2 Kd树的构造与搜索

作者: 壮少Bryant | 来源:发表于2019-05-24 00:23 被阅读1次

1 KD-Tree

实现kNN算法时,最简单的实现方法就是线性扫描,正如我们上一章节内容介绍的一样->K近邻算法,需要计算输入实例与每一个训练样本的距离。当训练集很大时,会非常耗时。

为了提高kNN搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数,KD-Tree就是其中的一种方法。

kd树是一个二叉树结构,相当于不断的用垂线将k维空间进行切分,构成一系列的k维超矩形区域。

2 如何构造KD-Tree

2.1 KD-Tree算法如下:

K维空间数据集 T =\{x_1,x_2,...,x_N \}
其中x_i = (x_i^{(1)},x_i^{(2)},...,x_i^{(k)})^T, i = 1,2,3,...,N

  1. 构造根节点
    选择x^{(1)}为坐标轴,将T中所有实例以x^{(1)}坐标为中位数,垂直x^{(1)}轴切成两个矩形,由根节点生成深度为1的左、右两个子节点:左子节点对应的坐标都小于切分点,右子节点坐标都大于切分点坐标。

  2. 重复:对深度为j的节点,选择x^{(l)}为切分的坐标轴, {l = j (mod\ k)+1},以该节点再次将矩形区域切分为两个子区域。

  3. 直到两个子区域没有实力存在时停止,从而形成KD-Tree的区域划分。

2.2 举例说明KD-Tree构造

随机生成 13 个点作为我们的数据集

13个随机点分布

首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标

根结点

并且按照该点的x坐标将空间进行切分,所有 x 坐标小于 6.27 的数据用于构建左分支,x坐标大于 6.27 的点用于构建右分支。

在下一步中r = 0 + 1 =1\ mod \ 2 ,对应 y 轴,左右两边再按照 y 轴的排序进行切分,中位点记载于左右枝的节点。得到下面的树,左边的 x 是指这该层的节点都是沿 x 轴进行分割的。

空间的切分如下

下一步中r=1+1=0 \ mod \ 2,对应 x 轴,所以下面再按照 x 坐标进行排序和切分,有

最后只剩下了叶子结点,就此完成了 kd 树的构造。

2.3 构造代码

class Node:
    def __init__(self, data, depth=0, lchild=None, rchild=None):
        self.data = data     # 此结点
        self.depth = depth   # 树的深度
        self.lchild = lchild # 左子结点
        self.rchild = rchild # 右子节点


class KdTree:
    def __init__(self):
        self.KdTree = None
        self.n = 0
        self.nearest = None

    def create(self, dataSet, depth=0):
        """KD-Tree创建过程"""
        if len(dataSet) > 0:
            m, n = np.shape(dataSet)
            self.n = n - 1
            # 按照哪个维度进行分割,比如0:x轴,1:y轴
            axis = depth % self.n
            # 中位数
            mid = int(m / 2)
            # 按照第几个维度(列)进行排序
            dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
            # KD结点为中位数的结点,树深度为depth
            node = Node(dataSetcopy[mid], depth)
            if depth == 0:
                self.KdTree = node
            # 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度
            node.lchild = self.create(dataSetcopy[:mid], depth+1)
            node.rchild = self.create(dataSetcopy[mid+1:], depth+1)
            return node
        return None

3 搜索KD-Tree

输入:已构造的kd树,目标点x
输出:x的k个最近邻集合L

3.1 KD-Tree的最近邻搜索算法

  1. 从根结点出发,递归向下访问KD-Tree,如果目标点x当前维小于切分点坐标,移动到左子节点,否则右子节点,直到子节点为叶子结点为止。

  2. 以此叶子结点为最近邻的点,插入到集合L中

  3. 递归向上回退,在这个节点进行以下操作:

  • a 如果该节点比L里的点更近,则替换集合L中距离最大的点。
  • b 目标点到此节点的分割线垂直的距离为d,判断集合L中距离最大的点与 d 相比较,如果比d大,说明d的另一侧区域中有可能有比集合L中距离要小,因此需要查看d的左右两个子节点的距离。
    如果集合L中距离最大的点比 d小,那说明另一侧区域的点距离目标点的距离都比d大,因此不用查找了,继续向上回退。
  1. 当回退到根结点时,搜索结束,最后的集合L里的k个点,就是x的最近邻点。

3.2 时间复杂度

KD-Tree的平均时间复杂度为O(logN),N为训练样本的数量。

KD-Tree试用于训练样本数远大于空间维度的k近邻搜索。当空间维数接近训练样本数时,他的效率会迅速下降,几乎接近线性扫描。

3.3 实例说明

设我们想查询的点为 p=(−1,−5),设距离函数是普通的距离,我们想找距离目标点最近的 k=3 个点。如下:

首先我们按照构造好的KD-Tree,从根结点开始查找

和这个节点的 x 轴比较一下,p 的 x 轴更小。因此我们向左枝进行搜索:


接下来需要对比 y 轴


p 的 y 值更小,因此向左枝进行搜索:

这个节点只有一个子枝,就不需要对比了。由此找到了叶子节点 (−4.6,−10.55)。


在二维图上是蓝色的点


此时我们要执行第二步,将当前结点插入到集合L中,并记录下 L=[(−4.6,−10.55)]。访问过的节点就在二叉树上显示为被划掉的好了。

然后执行第三步,不是最顶端节点。我回退。上面的结点是 (−6.88,−5.4)。

执行 3a,因为我们记录下的点只有一个,小于 k=3,所以也将当前节点记录下,插入到集合L中,有 L=[(−4.6,−10.55),(−6.88,−5.4)].。 因为当前节点的左枝是空的,所以直接跳过,继续回退,判断不是顶部根节点

由于还是不够三个点,于是将当前点也插入到集合L中,有 L=[(−4.6,−10.55),(−6.88,−5.4),(1.24,−2.86)]。

此时发现,当前节点有其他的分枝,执行3b,计算得出 p 点和 L 中的三个点的距离分别是 6.62, 5.89, 3.10,但是 p 和当前节点的分割线的距离只有 2.14,小于与 L 的最大距离:

到垂线距离小于L中最大的距离,说明垂线的另一侧可能有更近的点

因此,在分割线的另一端可能有更近的点。于是我们在当前结点的另一个分枝从头执行步骤1。好,我们在红线这里:

此时处于x轴切分,因此要用 p 和这个节点比较 x 坐标:

p 的 x 坐标更大,因此探索右枝 (1.75,12.26),并且发现右枝已经是最底部节点,执行步骤2与3a。


经计算,(1.75,12.26) 与 p 的距离是 17.48,要大于 p 与 L 的距离,因此我们不将其放入记录中。


然后 回退,判断出不是顶端节点,往上爬。

执行3a,这个节点与 p 的距离是 4.91,要小于 p 与 L 的最大距离 6.62。


因此,我们用这个新的节点替代 L 中离 p 最远的 (−4.6,−10.55)。

然后3b,我们比对 p 和当前节点的分割线的距离

image

这个距离小于 L 与 p 的最大距离,因此我们要到当前节点的另一个枝执行步骤1。当然,那个枝只有一个点。

计算距离发现这个点离 p 比 L 更远,因此不进行替代。

然后回退,不是根结点,我们向上爬

image

这个是已经访问过的了,所以再向上爬

再爬

此时到顶点了。所以完了吗?当然不,还要执行3b呢。现在是步骤1的回合。

我们进行计算比对发现顶端节点与p的距离比L还要更远,因此不进行更新。

然后计算 p 和分割线的距离发现也是更远。

因此也不需要检查另一个分枝。

判断当前节点是顶点,因此计算完成!输出距离 p 最近的三个样本是 L=[(−6.88,−5.4),(1.24,−2.86),(−2.96,−2.5)]。

3.3 代码

 def search(self, x, count=1):
        """KD-Tree的搜索"""
        nearest = [] # 记录近邻点的集合
        for i in range(count):
            nearest.append([-1, None])
        self.nearest = np.array(nearest)

        def recurve(node):
            """内方法,负责查找count个近邻点"""
            if node is not None:
                # 步骤1:怎么找叶子节点
                # 在哪个维度的分割线,0,1,0,1表示x,y,x,y
                axis = node.depth % self.n
                # 判断往左走or右走,递归,找到叶子结点
                daxis = x[axis] - node.data[axis]
                if daxis < 0:
                    recurve(node.lchild)
                else:
                    recurve(node.rchild)

                # 步骤2:满足的就插入到近邻点集合中
                # 求test点与此点的距离
                dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
                # 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的
                for i, d in enumerate(self.nearest):
                    if d[0] < 0 or dist < d[0]:
                        self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
                        self.nearest = self.nearest[:-1]
                        break

                # 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧
                n = list(self.nearest[:, 0]).count(-1)
                # -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离)
                # 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点
                if self.nearest[-n-1, 0] > abs(daxis):
                    # 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点
                    if daxis < 0:
                        recurve(node.rchild)
                    else:
                        recurve(node.lchild)
        
        recurve(self.KdTree)        # 调用根节点,开始查找
        knn = self.nearest[:, 1]    # knn为k个近邻结点
        belong = []                 # 记录k个近邻结点的分类
        for i in knn:
            belong.append(i.data[-1])
        b = max(set(belong), key=belong.count) # 找到测试点所属的分类

        return self.nearest, b

4 整体代码

import numpy as np
from math import sqrt
import pandas as pd
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

class Node:
    def __init__(self, data, depth=0, lchild=None, rchild=None):
        self.data = data     # 此结点
        self.depth = depth   # 树的深度
        self.lchild = lchild # 左子结点
        self.rchild = rchild # 右子节点


class KdTree:
    def __init__(self):
        self.KdTree = None
        self.n = 0
        self.nearest = None

    def create(self, dataSet, depth=0):
        """KD-Tree创建过程"""
        if len(dataSet) > 0:
            m, n = np.shape(dataSet)
            self.n = n - 1
            # 按照哪个维度进行分割,比如0:x轴,1:y轴
            axis = depth % self.n
            # 中位数
            mid = int(m / 2)
            # 按照第几个维度(列)进行排序
            dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
            # KD结点为中位数的结点,树深度为depth
            node = Node(dataSetcopy[mid], depth)
            if depth == 0:
                self.KdTree = node
            # 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度
            node.lchild = self.create(dataSetcopy[:mid], depth+1)
            node.rchild = self.create(dataSetcopy[mid+1:], depth+1)
            return node
        return None

    def preOrder(self, node):
        """遍历KD-Tree"""
        if node is not None:
            print(node.depth, node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

    def search(self, x, count=1):
        """KD-Tree的搜索"""
        nearest = [] # 记录近邻点的集合
        for i in range(count):
            nearest.append([-1, None])
        self.nearest = np.array(nearest)

        def recurve(node):
            """内方法,负责查找count个近邻点"""
            if node is not None:
                # 步骤1:怎么找叶子节点
                # 在哪个维度的分割线,0,1,0,1表示x,y,x,y
                axis = node.depth % self.n
                # 判断往左走or右走,递归,找到叶子结点
                daxis = x[axis] - node.data[axis]
                if daxis < 0:
                    recurve(node.lchild)
                else:
                    recurve(node.rchild)

                # 步骤2:满足的就插入到近邻点集合中
                # 求test点与此点的距离
                dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
                # 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的
                for i, d in enumerate(self.nearest):
                    if d[0] < 0 or dist < d[0]:
                        self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
                        self.nearest = self.nearest[:-1]
                        break

                # 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧
                n = list(self.nearest[:, 0]).count(-1)
                # -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离)
                # 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点
                if self.nearest[-n-1, 0] > abs(daxis):
                    # 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点
                    if daxis < 0:
                        recurve(node.rchild)
                    else:
                        recurve(node.lchild)
        
        recurve(self.KdTree)        # 调用根节点,开始查找
        knn = self.nearest[:, 1]    # knn为k个近邻结点
        belong = []                 # 记录k个近邻结点的分类
        for i in knn:
            belong.append(i.data[-1])
        b = max(set(belong), key=belong.count) # 找到测试点所属的分类

        return self.nearest, b


    def show_train():
        plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]')
        plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]')
        plt.xlabel('sepal length')
        plt.ylabel('sepal width')


if __name__ == "__main__":
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['label'] = iris.target
    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']

    data = np.array(df.iloc[:100, [0, 1, -1]])
    train, test = train_test_split(data, test_size=0.1)
    x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0])
    x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1])

    kdt = KdTree()
    kdt.create(train)
    kdt.preOrder(kdt.KdTree)

    score = 0
    for x in test:
        show_train()
        plt.scatter(x[0], x[1], c='red', marker='x')  # 测试点
        near, belong = kdt.search(x[:-1], 5)  # 设置临近点的个数
        if belong == x[-1]:
            score += 1
        print(x, "predict:", belong)
        print("nearest:")
        for n in near:
            print(n[1].data, "dist:", n[0])
            plt.scatter(n[1].data[0], n[1].data[1], c='green', marker='+')  # k个最近邻点
        plt.legend()
        plt.show()

    score /= len(test)
    print("score:", score)

声明:此文章为本人学习笔记,参考于:https://zhuanlan.zhihu.com/p/23966698

如果您觉得有用,欢迎关注我的公众号,我会不定期发布自己的学习笔记、AI资料、以及感悟,欢迎留言,与大家一起探索AI之路。

AI探索之路

相关文章

  • 2 Kd树的构造与搜索

    1 KD-Tree 实现kNN算法时,最简单的实现方法就是线性扫描,正如我们上一章节内容介绍的一样->K近邻算法,...

  • 构造kd树

    实现k近邻法时,主要考虑的问题是如何对训练数据进行快速的k近邻搜索,这点在特征空间的维数大及训练数据容量大时尤其重...

  • KNN 代码

    1、knn简单实现 2、knn回归 3、搜索优化KD树

  • k 近邻法

    k 近邻法 k 近邻算法 k 近邻模型 k 近邻法的实现:kd 树 搜索 kd 树 k 近邻模型实现 k 近邻模型...

  • kNN(构造kd树的实现)

    k近邻模型 1. 三个基本要素:k值的选择,距离度量,分类决策规则 2. 算法: 3. kd树的构造 python实现

  • 机器学习实战——KNN

    李航老师书寻找目标向量X最近的K个样本。通过KD树 实战通过线性搜索的方式(排序搜索)

  • kd树

    Kd-Tree,即K-dimensional tree,是一种高维索引树形数据结构,常用于在大规模的高维数据空间进...

  • KNN算法-4-算法优化-KD树

    KD树 KNN算法的重要步骤是对所有的实例点进行快速k近邻搜索。如果采用线性扫描(linear scan),要计算...

  • 数据结构之AVL树(自平衡二叉树)

    为什么需要AVL树? 在二叉搜索树专题讲过二叉搜索树一定程度上可以提高搜索效率,但是当原序列有序时,依据此序列构造...

  • knn,kd树

    一只兔子帮你理解 kNNhttps://www.joinquant.com/view/community/deta...

网友评论

    本文标题:2 Kd树的构造与搜索

    本文链接:https://www.haomeiwen.com/subject/wjebzqtx.html