Scala实现:KD-Tree(k-dimensional tree)
kd-tree是一种分割k维数据空间的数据结构。主要应用于多维空间数据的搜索,经常使用在SIFT、KNN等多维数据搜索的场景中,以KNN(K近邻)为例,使用线性搜索的方式效率低下,k-d树本质是对多维空间的划分,其每个节点都为k维点的二叉树kd-tree,因此可以大大提高搜索效率。
KD-Tree的构建步骤:
kd树实现步骤.jpg上述文字引自李航博士的《统计学习方法》
以{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}数据集为例构建KD-Tree。
KD-Tree空间划分示意图如下:
划分结果.png kdtree树结构.jpg关于三维数据的空间划分示意图如下所示
三维kdtree空间划分更多维度的数据划分只能靠脑补了······
KD-Tree最邻近搜索:
-
从根节点开始,递归的往下访问kd树,比较目标点与切分点在当前切分维度的大小,小于则移动到左子结点,大于则移动到右子结点,知道子结点为叶结点为止。
-
一旦移动到叶结点,将该结点当作"当前最邻近点"。
-
递归回退,对每个经过的叶结点递归地执行下列操作:
- 如果当前所在点比"当前最邻近点"更靠近输入点,则将其变为当前最邻近点。
-
- 当前最近点一定存在于该节点一个子结点对应的区域,检查另一子结点对应的区域是否与目标点为球心,以目标点与“当前最邻近点”之间的距离为半径的超球体相交:
- 1.如果相交,可能在另一结点对应之区域内存在距离目标点更近的点,移动到另一子结点,接着递归地进行最近邻搜索;
- 2.如果不相交,向上回退。
- 当回退到根节点时,搜索结束。最后的“当前最邻近点"即为x的最近邻点。
Scala代码实现
定义树节点
/**
*
* @param value 节点数据
* @param dim 当前切分维度
* @param left 左子结点
* @param right 右子结点
*/
case class TreeNode(value: Seq[Double],
dim: Int,
var left: TreeNode,
var right: TreeNode) {
var parent: TreeNode = _ //父结点
var brotherNode: TreeNode = _ //兄弟结点
if (left != null) {
left.parent = this
left.brotherNode = right
}
if (right != null) {
right.parent = this
right.brotherNode = left
}
}
创建KD-Tree
/**
*
* @param value 数据序列
* @param dim 当前划分的维度
* @param shape 数据维数
* @return
*/
def creatKdTree(value: Seq[Seq[Double]], dim: Int, shape: Int): TreeNode = {
// 数据按照当前划分的维度排序
val sorted = value.sortBy(_ (dim))
//中间位置的索引
val midIndex: Int = value.length / 2
sorted match {
// 当节点为空时,返回null
case Nil => null
//节点不为空时,递归调用
case _ =>
val left = sorted.slice(0, midIndex)
val right = sorted.slice(midIndex + 1, value.length)
val leftNode = creatKdTree(left, (dim + 1) % shape, shape) //左子节点递归创建树
val rightNode = creatKdTree(right, (dim + 1) % shape, shape) //右子节点递归创建树
TreeNode(sorted(midIndex), dim, leftNode, rightNode)
}
}
最近邻查找
// 欧式距离算法
def euclidean(p1: Seq[Double], p2: Seq[Double]) = {
require(p1.size == p2.size)
val d = p1
.zip(p2)
.map(tp => math.pow(tp._1 - tp._2, 2))
.sum
math.sqrt(d)
}
/**
*
* @param treeNode kdtree
* @param data 查询点
* 最近邻搜索
*/
def nearestSearch(treeNode: TreeNode, data: Seq[Double]): TreeNode = {
var nearestNode: TreeNode = null //当前最近节点
var minDist: Double = Double.MaxValue //当前最小距离
def finder(treeNode: TreeNode): TreeNode = {
treeNode match {
case null => nearestNode
case _ =>
val dimr = data(treeNode.dim) - treeNode.value(treeNode.dim)
if (dimr > 0) finder(treeNode.right) else finder(treeNode.left)
val distc = euclidean(treeNode.value, data)
if (distc <= minDist) {
minDist = distc
nearestNode = treeNode
}
// 目标点与当前节点相交
if (math.abs(dimr) < minDist)
if (dimr > 0) finder(treeNode.left) else finder(treeNode.right)
nearestNode
}
}
finder(treeNode)
}
结果查看
val nodes: Seq[Seq[Double]] =
Seq(Seq(2, 3), Seq(5, 4), Seq(9, 6), Seq(4, 7), Seq(8, 1), Seq(7, 2))
val treeNode: TreeNode = KdTree.creatKdTree(nodes, 0, 2)
println(treeNode)
println(KdTree.nearestSearch(treeNode, Seq(2.1, 4.5)).value)
println("==============")
nodes.map(x => {
val d = KdTree.euclidean(x, Seq(2.1, 4.5))
(d, x)
})
.sortBy(_._1)
.foreach(println)
TreeNode(List(7.0, 2.0),0,TreeNode(List(5.0, 4.0),1,TreeNode(List(2.0, 3.0),0,null,null),TreeNode(List(4.0, 7.0),0,null,null)),TreeNode(List(9.0, 6.0),1,TreeNode(List(8.0, 1.0),0,null,null),null))
List(2.0, 3.0)
==============
(1.503329637837291,List(2.0, 3.0))
(2.9427877939124323,List(5.0, 4.0))
(3.1400636936215163,List(4.0, 7.0))
(5.500909015790027,List(7.0, 2.0))
(6.860029154456998,List(8.0, 1.0))
(7.061161377563892,List(9.0, 6.0))
TODO K近邻查找(KNN)
/**
* 从root节点开始,DFS搜索直到叶子节点,同时在stack中顺序存储已经访问的节点。
* 如果搜索到叶子节点,当前的叶子节点被设为最近邻节点。
* 然后通过stack回溯:
* 如果当前点的距离比最近邻点距离近,更新最近邻节点.
* 然后检查以最近距离为半径的圆是否和父节点的超平面相交.
* 如果相交,则必须到父节点的另外一侧,用同样的DFS搜索法,开始检查最近邻节点。
* 如果不相交,则继续往上回溯,而父节点的另一侧子节点都被淘汰,不再考虑的范围中.
* 当搜索回到root节点时,搜索完成,得到最近邻节点。
*
* @param treeNode
* @param data
* @param k
* @return
*/
def knn(treeNode: TreeNode, data: Seq[Double], k: Int) = {
var resArr = new Array[(Double, TreeNode)](k)
.map(_ => (Double.MaxValue, null))
.asInstanceOf[Array[(Double, TreeNode)]]
def finder(treeNode: TreeNode): TreeNode = {
if (treeNode != null) {
val dimr = data(treeNode.dim) - treeNode.value(treeNode.dim)
if (dimr > 0) finder(treeNode.right) else finder(treeNode.left)
val distc: Double = distanceUtils.euclidean(treeNode.value, data)
if (distc < resArr.last._1 ) {
resArr.update(k - 1, (distc, treeNode))
resArr = resArr.sortBy(_._1)
}
if (math.abs(dimr) < resArr.last._1)
if (dimr > 0) finder(treeNode.left) else finder(treeNode.right)
}
resArr.last._2
}
finder(treeNode)
resArr
}
KNN结果查看
KdTree
.knn(treeNode, Seq(2.1, 4.5), 3)
.map(x => (x._1, x._2.value))
.foreach(println)
(1.503329637837291,List(2.0, 3.0))
(2.9427877939124323,List(5.0, 4.0))
(3.1400636936215163,List(4.0, 7.0))
参考资料
https://baike.baidu.com/item/kd-tree/2302515?fr=aladdin#7_1
《统计学习方法》
网友评论