美文网首页
最近邻查找算法KDtree

最近邻查找算法KDtree

作者：overflow_e4e4 | 发表于 2019-06-25 14:17 | 被阅读 0 次

本文的 KD-tree 实现参考了以下两篇博文：
《k-d tree算法原理及实现》、《最近邻查找算法kd-tree》

package kdtree;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/**
 * A 2-D k-d tree over {@link Position} points that answers nearest-neighbour queries.
 *
 * <p>Usage: assign the points to {@link #list}, call {@link #rebuild()}, then query with
 * {@link #findClosest(Position)}. The split axis alternates with depth: the root (depth 1)
 * splits on x, its children (depth 2) on y, and so on.
 *
 * <p>{@link #rebuild()} reorders {@link #list} in place and replaces {@link #root}; the class
 * has no synchronization.
 */
public class TwoDimonsionalTree {

    /**
     * A tree node holding one point.
     *
     * @param <T> type of the stored value
     */
    static class Node<T> {
        /** The point stored at this node. */
        T value;
        /** Subtree whose points are not greater than {@code value} on this node's split axis. */
        Node<T> left;
        /** Subtree whose points are not less than {@code value} on this node's split axis. */
        Node<T> right;
        /** Depth of this node (the root is 1); selects the split axis via {@code depth % 2}. */
        int depth;

        @Override
        public String toString() {
            return "Node{" +
                    "value=" + value +
                    '}';
        }
    }

    /** Split-axis code for x; produced by {@code calCutDim} for odd depths. */
    private static final int CUT_DIMENSION_X = 1;

    /** Split-axis code for y; produced by {@code calCutDim} for even depths. */
    private static final int CUT_DIMENSION_Y = 0;

    /** Orders points by x coordinate. */
    private static final Comparator<Position> xComparator = Comparator.comparingInt(Position::getX);

    /** Orders points by y coordinate. */
    private static final Comparator<Position> yComparator = Comparator.comparingInt(Position::getY);

    /** Root of the tree; {@code null} until {@link #rebuild()} runs on a non-empty list. */
    Node<Position> root;

    /** Source points the tree is built from; reordered in place by {@link #rebuild()}. */
    List<Position> list = new ArrayList<>();

    /** Not read or written by this class; retained for compatibility with existing callers. */
    int depth;

    /**
     * Rebuilds the tree from {@link #list}. A {@code null} or empty list leaves {@link #root}
     * {@code null}.
     */
    public void rebuild() {
        root = rebuild(list, 1);
    }

    /**
     * Recursively builds a balanced subtree by sorting on the depth's split axis and taking the
     * median as the subtree root.
     *
     * @param points points for this subtree; sorted in place
     * @param depth  depth of the subtree root (root of the whole tree is 1)
     * @return the subtree root, or {@code null} if {@code points} is null or empty
     */
    private Node<Position> rebuild(List<Position> points, int depth) {
        if (points == null || points.isEmpty()) {
            return null;
        }
        Node<Position> node = new Node<>();

        int middle = points.size() / 2;
        points.sort(getComparatorByDepth(depth));
        node.value = points.get(middle);
        node.depth = depth;
        // subList views share the backing list, so children sort their own slice in place.
        node.left = rebuild(points.subList(0, middle), depth + 1);
        node.right = rebuild(points.subList(middle + 1, points.size()), depth + 1);
        return node;
    }

    /** Returns the comparator for the split axis used at {@code depth}. */
    private Comparator<Position> getComparatorByDepth(int depth) {
        return calCutDim(depth) == CUT_DIMENSION_X ? xComparator : yComparator;
    }

    /**
     * Returns the split-axis code for a node at {@code depth}.
     *
     * @param depth node depth (root is 1)
     * @return {@code CUT_DIMENSION_X} for odd depths, {@code CUT_DIMENSION_Y} for even depths
     */
    private int calCutDim(int depth) {
        return depth % 2;
    }

    /**
     * Returns whether {@code target} lies strictly before {@code pivot} on the given axis, i.e.
     * whether a search for {@code target} descends into the left subtree of {@code pivot}'s node.
     * Ties go right.
     */
    private boolean goesLeft(int cutDim, Position target, Position pivot) {
        if (cutDim == CUT_DIMENSION_X) {
            return target.getX() < pivot.getX();
        }
        return target.getY() < pivot.getY();
    }

    /**
     * Euclidean distance between two points.
     *
     * @return the straight-line distance between {@code p1} and {@code p2}
     */
    private double calDistance(Position p1, Position p2) {
        // Compute in double: int differences and their squares can overflow int (and even long).
        double dx = (double) p1.getX() - p2.getX();
        double dy = (double) p1.getY() - p2.getY();
        return Math.sqrt(dx * dx + dy * dy);
    }

    /**
     * Finds the stored point nearest to {@code target} (Euclidean distance). When several points
     * are equally near, the first one reached by the search is returned.
     *
     * @param target query point
     * @return the nearest stored point
     * @throws IllegalStateException if the tree has not been built or is empty
     */
    public Position findClosest(Position target) {
        if (root == null) {
            throw new IllegalStateException("tree has not been rebuild or is empty");
        }

        // Stack of internal nodes visited on the way down, still awaiting backtracking.
        List<Node<Position>> path = new ArrayList<>();
        Tuple<Position, Double> best = search(target, root, path, null);
        while (!path.isEmpty()) {
            best = reCall(path, best, target);
        }
        return best.getLeft();
    }

    /**
     * Descends from {@code root} toward {@code target}, pushing every node with at least one
     * child onto {@code list}, and scores the node where the descent stops.
     *
     * @param target   query point
     * @param root     non-null subtree root to start from
     * @param list     backtracking stack; visited internal nodes are appended
     * @param minTuple best (point, distance) so far, or {@code null} if none yet
     * @return the better of {@code minTuple} and the node where the descent stopped
     */
    private Tuple<Position, Double> search(Position target, Node<Position> root, List<Node<Position>> list, Tuple<Position, Double> minTuple) {
        Node<Position> temp = root;
        while (temp.left != null || temp.right != null) {
            list.add(temp);
            Node<Position> next = goesLeft(calCutDim(temp.depth), target, temp.value) ? temp.left : temp.right;
            if (next == null) {
                // Preferred side is empty; stop here. The other side is explored when this
                // node is popped during backtracking.
                break;
            }
            temp = next;
        }

        double dis = calDistance(temp.value, target);
        if (minTuple == null || dis < minTuple.getRight()) {
            return new Tuple<>(temp.value, dis);
        }
        return minTuple;
    }

    /**
     * Backtracking step: pops one node from {@code list}, scores it, and searches its far subtree
     * if the splitting line is close enough to possibly hide a nearer point.
     *
     * @param list     non-empty backtracking stack
     * @param minTuple best (point, distance) so far
     * @param target   query point
     * @return the updated best (point, distance)
     */
    private Tuple<Position, Double> reCall(List<Node<Position>> list, Tuple<Position, Double> minTuple, Position target) {
        // Pop the most recently pushed node (the list is used as a stack).
        Node<Position> temp = list.remove(list.size() - 1);
        int cutDim = calCutDim(temp.depth);

        // Prune: if the splitting line is farther away than the current best, neither this node
        // nor anything on the far side of it can be closer.
        if (calDisCutDim(target, temp.value, cutDim) > minTuple.getRight()) {
            return minTuple;
        }

        double dis = calDistance(temp.value, target);
        if (dis < minTuple.getRight()) {
            minTuple = new Tuple<>(temp.value, dis);
        }

        // The forward search went to the near side; the far side may still hold a closer point.
        Node<Position> far = goesLeft(cutDim, target, temp.value) ? temp.right : temp.left;
        if (far != null) {
            minTuple = search(target, far, list, minTuple);
        }
        return minTuple;
    }

    /**
     * Absolute distance between two points along a single axis.
     *
     * @param p1     first point
     * @param p2     second point
     * @param cutDim axis code ({@code CUT_DIMENSION_X} or {@code CUT_DIMENSION_Y})
     * @return |p1 - p2| on that axis, computed in double to avoid int overflow
     */
    private double calDisCutDim(Position p1, Position p2, int cutDim) {
        if (cutDim == CUT_DIMENSION_X) {
            return Math.abs((double) p1.getX() - p2.getX());
        }
        return Math.abs((double) p1.getY() - p2.getY());
    }

    /** Small demo: builds a tree from six points and prints the point nearest to (3, 5). */
    public static void main(String[] args) {
        TwoDimonsionalTree tree = new TwoDimonsionalTree();
        List<Position> list = new ArrayList<>();
        list.add(new Position(7, 2));
        list.add(new Position(5, 4));
        list.add(new Position(9, 6));
        list.add(new Position(2, 3));
        list.add(new Position(4, 7));
        list.add(new Position(8, 1));
        tree.list = list;
        tree.rebuild();

        Position closest = tree.findClosest(new Position(3, 5));
        System.out.println("closest to (3, 5): (" + closest.getX() + ", " + closest.getY() + ")");
    }

}


相关文章

网友评论

      本文标题:最近邻查找算法KDtree

      本文链接:https://www.haomeiwen.com/subject/kwfjqctx.html