1 KD-Tree
实现kNN算法时,最简单的实现方法就是线性扫描,正如我们上一章节内容介绍的一样->K近邻算法,需要计算输入实例与每一个训练样本的距离。当训练集很大时,会非常耗时。
为了提高kNN搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数,KD-Tree就是其中的一种方法。
kd树是一个二叉树结构,相当于不断的用垂线将k维空间进行切分,构成一系列的k维超矩形区域。
2 如何构造KD-Tree
2.1 KD-Tree算法如下:
K维空间数据集
其中
-
构造根节点
选择为坐标轴,将T中所有实例以坐标为中位数,垂直轴切成两个矩形,由根节点生成深度为1的左、右两个子节点:左子节点对应的坐标都小于切分点,右子节点坐标都大于切分点坐标。 -
重复:对深度为j的节点,选择为切分的坐标轴, ,以该节点再次将矩形区域切分为两个子区域。
-
直到两个子区域没有实力存在时停止,从而形成KD-Tree的区域划分。
2.2 举例说明KD-Tree构造
随机生成 13 个点作为我们的数据集
13个随机点分布首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标
根结点并且按照该点的x坐标将空间进行切分,所有 x 坐标小于 6.27 的数据用于构建左分支,x坐标大于 6.27 的点用于构建右分支。
在下一步中 ,对应 y 轴,左右两边再按照 y 轴的排序进行切分,中位点记载于左右枝的节点。得到下面的树,左边的 x 是指这该层的节点都是沿 x 轴进行分割的。
空间的切分如下
下一步中,对应 x 轴,所以下面再按照 x 坐标进行排序和切分,有
最后只剩下了叶子结点,就此完成了 kd 树的构造。
2.3 构造代码
class Node:
def __init__(self, data, depth=0, lchild=None, rchild=None):
self.data = data # 此结点
self.depth = depth # 树的深度
self.lchild = lchild # 左子结点
self.rchild = rchild # 右子节点
class KdTree:
def __init__(self):
self.KdTree = None
self.n = 0
self.nearest = None
def create(self, dataSet, depth=0):
"""KD-Tree创建过程"""
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n - 1
# 按照哪个维度进行分割,比如0:x轴,1:y轴
axis = depth % self.n
# 中位数
mid = int(m / 2)
# 按照第几个维度(列)进行排序
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
# KD结点为中位数的结点,树深度为depth
node = Node(dataSetcopy[mid], depth)
if depth == 0:
self.KdTree = node
# 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度
node.lchild = self.create(dataSetcopy[:mid], depth+1)
node.rchild = self.create(dataSetcopy[mid+1:], depth+1)
return node
return None
3 搜索KD-Tree
输入:已构造的kd树,目标点x
输出:x的k个最近邻集合L
3.1 KD-Tree的最近邻搜索算法
-
从根结点出发,递归向下访问KD-Tree,如果目标点x当前维小于切分点坐标,移动到左子节点,否则右子节点,直到子节点为叶子结点为止。
-
以此叶子结点为最近邻的点,插入到集合L中
-
递归向上回退,在这个节点进行以下操作:
- a 如果该节点比L里的点更近,则替换集合L中距离最大的点。
- b 目标点到此节点的分割线垂直的距离为d,判断集合L中距离最大的点与 d 相比较,如果比d大,说明d的另一侧区域中有可能有比集合L中距离要小,因此需要查看d的左右两个子节点的距离。
如果集合L中距离最大的点比 d小,那说明另一侧区域的点距离目标点的距离都比d大,因此不用查找了,继续向上回退。
- 当回退到根结点时,搜索结束,最后的集合L里的k个点,就是x的最近邻点。
3.2 时间复杂度
KD-Tree的平均时间复杂度为,N为训练样本的数量。
KD-Tree试用于训练样本数远大于空间维度的k近邻搜索。当空间维数接近训练样本数时,他的效率会迅速下降,几乎接近线性扫描。
3.3 实例说明
设我们想查询的点为 p=(−1,−5),设距离函数是普通的距离,我们想找距离目标点最近的 k=3 个点。如下:
首先我们按照构造好的KD-Tree,从根结点开始查找
和这个节点的 x 轴比较一下,p 的 x 轴更小。因此我们向左枝进行搜索:
接下来需要对比 y 轴
p 的 y 值更小,因此向左枝进行搜索:
这个节点只有一个子枝,就不需要对比了。由此找到了叶子节点 (−4.6,−10.55)。
在二维图上是蓝色的点
此时我们要执行第二步,将当前结点插入到集合L中,并记录下 L=[(−4.6,−10.55)]。访问过的节点就在二叉树上显示为被划掉的好了。
然后执行第三步,不是最顶端节点。我回退。上面的结点是 (−6.88,−5.4)。
执行 3a,因为我们记录下的点只有一个,小于 k=3,所以也将当前节点记录下,插入到集合L中,有 L=[(−4.6,−10.55),(−6.88,−5.4)].。 因为当前节点的左枝是空的,所以直接跳过,继续回退,判断不是顶部根节点
由于还是不够三个点,于是将当前点也插入到集合L中,有 L=[(−4.6,−10.55),(−6.88,−5.4),(1.24,−2.86)]。
此时发现,当前节点有其他的分枝,执行3b,计算得出 p 点和 L 中的三个点的距离分别是 6.62, 5.89, 3.10,但是 p 和当前节点的分割线的距离只有 2.14,小于与 L 的最大距离:
到垂线距离小于L中最大的距离,说明垂线的另一侧可能有更近的点因此,在分割线的另一端可能有更近的点。于是我们在当前结点的另一个分枝从头执行步骤1。好,我们在红线这里:
此时处于x轴切分,因此要用 p 和这个节点比较 x 坐标:
p 的 x 坐标更大,因此探索右枝 (1.75,12.26),并且发现右枝已经是最底部节点,执行步骤2与3a。
经计算,(1.75,12.26) 与 p 的距离是 17.48,要大于 p 与 L 的距离,因此我们不将其放入记录中。
然后 回退,判断出不是顶端节点,往上爬。
执行3a,这个节点与 p 的距离是 4.91,要小于 p 与 L 的最大距离 6.62。
因此,我们用这个新的节点替代 L 中离 p 最远的 (−4.6,−10.55)。
然后3b,我们比对 p 和当前节点的分割线的距离
image这个距离小于 L 与 p 的最大距离,因此我们要到当前节点的另一个枝执行步骤1。当然,那个枝只有一个点。
计算距离发现这个点离 p 比 L 更远,因此不进行替代。
然后回退,不是根结点,我们向上爬
image这个是已经访问过的了,所以再向上爬
再爬
此时到顶点了。所以完了吗?当然不,还要执行3b呢。现在是步骤1的回合。
我们进行计算比对发现顶端节点与p的距离比L还要更远,因此不进行更新。
然后计算 p 和分割线的距离发现也是更远。
因此也不需要检查另一个分枝。
判断当前节点是顶点,因此计算完成!输出距离 p 最近的三个样本是 L=[(−6.88,−5.4),(1.24,−2.86),(−2.96,−2.5)]。
3.3 代码
def search(self, x, count=1):
"""KD-Tree的搜索"""
nearest = [] # 记录近邻点的集合
for i in range(count):
nearest.append([-1, None])
self.nearest = np.array(nearest)
def recurve(node):
"""内方法,负责查找count个近邻点"""
if node is not None:
# 步骤1:怎么找叶子节点
# 在哪个维度的分割线,0,1,0,1表示x,y,x,y
axis = node.depth % self.n
# 判断往左走or右走,递归,找到叶子结点
daxis = x[axis] - node.data[axis]
if daxis < 0:
recurve(node.lchild)
else:
recurve(node.rchild)
# 步骤2:满足的就插入到近邻点集合中
# 求test点与此点的距离
dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
# 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的
for i, d in enumerate(self.nearest):
if d[0] < 0 or dist < d[0]:
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
self.nearest = self.nearest[:-1]
break
# 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧
n = list(self.nearest[:, 0]).count(-1)
# -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离)
# 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点
if self.nearest[-n-1, 0] > abs(daxis):
# 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点
if daxis < 0:
recurve(node.rchild)
else:
recurve(node.lchild)
recurve(self.KdTree) # 调用根节点,开始查找
knn = self.nearest[:, 1] # knn为k个近邻结点
belong = [] # 记录k个近邻结点的分类
for i in knn:
belong.append(i.data[-1])
b = max(set(belong), key=belong.count) # 找到测试点所属的分类
return self.nearest, b
4 整体代码
import numpy as np
from math import sqrt
import pandas as pd
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
class Node:
def __init__(self, data, depth=0, lchild=None, rchild=None):
self.data = data # 此结点
self.depth = depth # 树的深度
self.lchild = lchild # 左子结点
self.rchild = rchild # 右子节点
class KdTree:
def __init__(self):
self.KdTree = None
self.n = 0
self.nearest = None
def create(self, dataSet, depth=0):
"""KD-Tree创建过程"""
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n - 1
# 按照哪个维度进行分割,比如0:x轴,1:y轴
axis = depth % self.n
# 中位数
mid = int(m / 2)
# 按照第几个维度(列)进行排序
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
# KD结点为中位数的结点,树深度为depth
node = Node(dataSetcopy[mid], depth)
if depth == 0:
self.KdTree = node
# 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度
node.lchild = self.create(dataSetcopy[:mid], depth+1)
node.rchild = self.create(dataSetcopy[mid+1:], depth+1)
return node
return None
def preOrder(self, node):
"""遍历KD-Tree"""
if node is not None:
print(node.depth, node.data)
self.preOrder(node.lchild)
self.preOrder(node.rchild)
def search(self, x, count=1):
"""KD-Tree的搜索"""
nearest = [] # 记录近邻点的集合
for i in range(count):
nearest.append([-1, None])
self.nearest = np.array(nearest)
def recurve(node):
"""内方法,负责查找count个近邻点"""
if node is not None:
# 步骤1:怎么找叶子节点
# 在哪个维度的分割线,0,1,0,1表示x,y,x,y
axis = node.depth % self.n
# 判断往左走or右走,递归,找到叶子结点
daxis = x[axis] - node.data[axis]
if daxis < 0:
recurve(node.lchild)
else:
recurve(node.rchild)
# 步骤2:满足的就插入到近邻点集合中
# 求test点与此点的距离
dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
# 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的
for i, d in enumerate(self.nearest):
if d[0] < 0 or dist < d[0]:
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
self.nearest = self.nearest[:-1]
break
# 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧
n = list(self.nearest[:, 0]).count(-1)
# -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离)
# 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点
if self.nearest[-n-1, 0] > abs(daxis):
# 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点
if daxis < 0:
recurve(node.rchild)
else:
recurve(node.lchild)
recurve(self.KdTree) # 调用根节点,开始查找
knn = self.nearest[:, 1] # knn为k个近邻结点
belong = [] # 记录k个近邻结点的分类
for i in knn:
belong.append(i.data[-1])
b = max(set(belong), key=belong.count) # 找到测试点所属的分类
return self.nearest, b
def show_train():
plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]')
plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
if __name__ == "__main__":
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
data = np.array(df.iloc[:100, [0, 1, -1]])
train, test = train_test_split(data, test_size=0.1)
x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0])
x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1])
kdt = KdTree()
kdt.create(train)
kdt.preOrder(kdt.KdTree)
score = 0
for x in test:
show_train()
plt.scatter(x[0], x[1], c='red', marker='x') # 测试点
near, belong = kdt.search(x[:-1], 5) # 设置临近点的个数
if belong == x[-1]:
score += 1
print(x, "predict:", belong)
print("nearest:")
for n in near:
print(n[1].data, "dist:", n[0])
plt.scatter(n[1].data[0], n[1].data[1], c='green', marker='+') # k个最近邻点
plt.legend()
plt.show()
score /= len(test)
print("score:", score)
声明:此文章为本人学习笔记,参考于:https://zhuanlan.zhihu.com/p/23966698
如果您觉得有用,欢迎关注我的公众号,我会不定期发布自己的学习笔记、AI资料、以及感悟,欢迎留言,与大家一起探索AI之路。
AI探索之路
网友评论