根据以下两篇博文实现的 KD-tree：
《k-d tree算法原理及实现》、《最近邻查找算法kd-tree》。
package kdtree;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Deque;
import java.util.List;
/**
 * A 2-d tree (a k-d tree with k = 2) over integer {@link Position}s that answers
 * nearest-neighbour queries.
 *
 * <p>Usage: assign the points to {@link #list}, call {@link #rebuild()}, then query with
 * {@link #findClosest(Position)}. The root is at depth 1. The split dimension alternates by
 * depth: odd depths split on x, even depths split on y.
 */
public class TwoDimonsionalTree {
    /**
     * A tree node holding one point.
     *
     * @param <T> the type of the stored value
     */
    class Node<T> {
        /** The point stored at this node. */
        T value;
        /** Points not greater than {@code value} on this node's split dimension. */
        Node<T> left;
        /** Points not smaller than {@code value} on this node's split dimension. */
        Node<T> right;
        /** Depth of this node (root = 1); selects the split dimension via {@code calCutDim}. */
        int depth;

        @Override
        public String toString() {
            return "Node{" +
                    "value=" + value +
                    '}';
        }
    }

    /** Split-dimension code for the x axis (the value {@code calCutDim} yields at odd depths). */
    private static final int CUT_DIMONSION_X = 1;
    /** Split-dimension code for the y axis (the value {@code calCutDim} yields at even depths). */
    private static final int CUT_DIMONSION_Y = 0;
    /** Orders points by x coordinate. */
    private static final Comparator<Position> xComparator = Comparator.comparingInt(Position::getX);
    /** Orders points by y coordinate. */
    private static final Comparator<Position> yComparator = Comparator.comparingInt(Position::getY);

    /** Root of the tree; {@code null} until {@link #rebuild()} has run on a non-empty list. */
    Node<Position> root;
    /** Points the tree is built from by {@link #rebuild()}. */
    List<Position> list = new ArrayList<>();
    /** NOTE(review): not read or written anywhere in this class; confirm before removing. */
    int depth;

    /**
     * (Re)builds the tree from the points currently in {@link #list}.
     *
     * <p>The points are copied first, so the caller's list is not reordered. An empty or
     * {@code null} list yields an empty tree ({@link #root} becomes {@code null}).
     */
    public void rebuild() {
        root = rebuild(list == null ? null : new ArrayList<>(list), 1);
    }

    /**
     * Recursively builds a balanced subtree by splitting at the median of the current dimension.
     *
     * @param points points for this subtree; sorted in place (callers pass a private copy/view)
     * @param depth  depth of the subtree root, which selects the split dimension
     * @return the subtree root, or {@code null} when {@code points} is null or empty
     */
    private Node<Position> rebuild(List<Position> points, int depth) {
        if (points == null || points.isEmpty()) {
            return null;
        }
        Node<Position> node = new Node<>();
        int middle = points.size() / 2;
        points.sort(getComparatorByDepth(depth));
        node.value = points.get(middle);
        node.depth = depth;
        // subList views share the sorted backing list, so no extra copies are made.
        node.left = rebuild(points.subList(0, middle), depth + 1);
        node.right = rebuild(points.subList(middle + 1, points.size()), depth + 1);
        return node;
    }

    /** Returns the comparator for the split dimension used at {@code depth}. */
    private Comparator<Position> getComparatorByDepth(int depth) {
        return calCutDim(depth) == CUT_DIMONSION_X ? xComparator : yComparator;
    }

    /**
     * Returns the split dimension for a depth.
     *
     * @param depth node depth (root = 1)
     * @return {@link #CUT_DIMONSION_X} for odd depths, {@link #CUT_DIMONSION_Y} for even depths
     */
    private int calCutDim(int depth) {
        return depth % 2;
    }

    /** Returns the coordinate of {@code p} along split dimension {@code cutDim}. */
    private static int coordinate(Position p, int cutDim) {
        return cutDim == CUT_DIMONSION_X ? p.getX() : p.getY();
    }

    /**
     * Tells whether a search for {@code target} descends into {@code node}'s left subtree. That is
     * the case when the target lies strictly below the node's value on the node's split dimension.
     */
    private boolean goesLeft(Position target, Node<Position> node) {
        int dim = calCutDim(node.depth);
        return coordinate(target, dim) < coordinate(node.value, dim);
    }

    /**
     * Euclidean distance between two points.
     * It is computed in {@code double} so that large coordinates cannot overflow {@code int}.
     */
    private double calDistance(Position p1, Position p2) {
        double dx = (double) p1.getX() - p2.getX();
        double dy = (double) p1.getY() - p2.getY();
        return Math.sqrt(dx * dx + dy * dy);
    }

    /**
     * Distance between two points measured along a single split dimension only.
     * It is computed in {@code double} to avoid {@code int} overflow.
     */
    private double calDisCutDim(Position p1, Position p2, int cutDim) {
        return Math.abs((double) coordinate(p1, cutDim) - coordinate(p2, cutDim));
    }

    /**
     * Finds the stored point closest to {@code target} (Euclidean distance). When several points
     * are equally close, the first one encountered during the search is kept.
     *
     * @param target the query point
     * @return the nearest stored point
     * @throws IllegalStateException if the tree has not been built or is empty
     */
    public Position findClosest(Position target) {
        if (root == null) {
            throw new IllegalStateException("tree has not been rebuild or is empty");
        }
        Deque<Node<Position>> pending = new ArrayDeque<>();
        Tuple<Position, Double> best = search(target, root, pending, null);
        while (!pending.isEmpty()) {
            best = reCall(pending, best, target);
        }
        return best.getLeft();
    }

    /**
     * Descends from {@code start} toward {@code target}. Every visited inner node is pushed onto
     * {@code pending} for later backtracking. The node where the descent stops is then compared
     * against the current best.
     *
     * @param target  the query point
     * @param start   subtree root to descend from (must be non-null)
     * @param pending stack of nodes still to be backtracked over
     * @param best    best (point, distance) found so far, or {@code null} if none yet
     * @return the updated best (point, distance)
     */
    private Tuple<Position, Double> search(Position target, Node<Position> start,
                                           Deque<Node<Position>> pending,
                                           Tuple<Position, Double> best) {
        Node<Position> current = start;
        while (current.left != null || current.right != null) {
            pending.push(current);
            Node<Position> next = goesLeft(target, current) ? current.left : current.right;
            if (next == null) {
                // Preferred side is empty; the other side is visited while backtracking.
                break;
            }
            current = next;
        }
        return closer(best, current.value, target);
    }

    /**
     * Returns a new best if {@code candidate} is strictly closer to {@code target} than
     * {@code best} (or if {@code best} is {@code null}); otherwise returns {@code best}.
     */
    private Tuple<Position, Double> closer(Tuple<Position, Double> best, Position candidate,
                                           Position target) {
        double dis = calDistance(candidate, target);
        if (best == null || dis < best.getRight()) {
            return new Tuple<>(candidate, dis);
        }
        return best;
    }

    /**
     * One backtracking step. Pops a node from {@code pending} and considers its own point. If the
     * split line is no farther than the current best, it also searches the subtree on the far
     * side of the split.
     *
     * @param pending stack of nodes still to be backtracked over (non-empty)
     * @param best    best (point, distance) found so far (non-null)
     * @param target  the query point
     * @return the updated best (point, distance)
     */
    private Tuple<Position, Double> reCall(Deque<Node<Position>> pending,
                                           Tuple<Position, Double> best, Position target) {
        Node<Position> node = pending.pop();
        // The node itself and every point across the split are at least this far from the
        // target, so if the split line is already farther than the best match, prune.
        if (calDisCutDim(target, node.value, calCutDim(node.depth)) > best.getRight()) {
            return best;
        }
        best = closer(best, node.value, target);
        // Keep going even if this node improved the best: the far side may be closer still.
        Node<Position> farSide = goesLeft(target, node) ? node.right : node.left;
        if (farSide != null) {
            best = search(target, farSide, pending, best);
        }
        return best;
    }

    /** Demo: builds a tree from six points and prints the point nearest to (3, 5). */
    public static void main(String[] args) {
        TwoDimonsionalTree tree = new TwoDimonsionalTree();
        List<Position> list = new ArrayList<>();
        list.add(new Position(7, 2));
        list.add(new Position(5, 4));
        list.add(new Position(9, 6));
        list.add(new Position(2, 3));
        list.add(new Position(4, 7));
        list.add(new Position(8, 1));
        tree.list = list;
        tree.rebuild();
        System.out.println(tree.findClosest(new Position(3, 5)));
    }
}
网友评论