美文网首页
第三章 kd-tree

第三章 kd-tree

作者：sadamu0912 | 发表于 2019-11-15 22:54 | 被阅读 0 次
    import numpy as np
    from operator import itemgetter
    # 获取两个特征点的二范数,也叫欧式距离
    def _get_euclidean_distance(feature1,feature2):
        return ((feature1-feature2) ** 2).sum() ** 0.5
    class Node(object):
        """A single node of a kd-tree.

        Every field is an optional keyword argument, so the one constructor
        serves both as an empty placeholder (``Node()``) and as a fully
        populated node (``Node(feature=..., split=...)``).

        Keyword Args:
            father: parent node (optional, defaults to None).
            left: left child node, or None.
            right: right child node, or None.
            feature: the sample point stored at this node.
            split: index of the dimension this node splits on.
        """

        def __init__(self, *args, **kwargs):
            # Positional arguments are accepted but ignored; all data comes
            # through keywords so missing fields simply default to None.
            self.father = kwargs.get("father")
            self.left = kwargs.get("left")
            self.right = kwargs.get("right")
            # The sample point stored at this node.
            self.feature = kwargs.get("feature")
            # The dimension index used to split at this node.
            self.split = kwargs.get("split")

        def equals(self, node):
            """Return True if ``node`` stores the same feature vector as this node.

            ``np.array_equal`` works for NumPy arrays and plain sequences alike
            and returns False (instead of raising) when the shapes differ.
            """
            return bool(np.array_equal(self.feature, node.feature))
    class KDTree(object):
        """A kd-tree over fixed-length feature vectors with nearest-neighbour search.

        Typical use::

            tree = KDTree()
            root = tree.build_tree(points, 0)
            nearest = tree.nearest_neighbor_search(query, root)
        """

        def __init__(self):
            # Placeholder root; build_tree returns the real root rather than
            # storing it here.
            self.root = Node()

        @staticmethod
        def _distance(a, b):
            """Return the Euclidean distance between two coordinate vectors."""
            return float(np.linalg.norm(np.asarray(a, dtype=float) - np.asarray(b, dtype=float)))

        def _choose_feature(self, X, idxs):
            """Return the dimension with the largest variance over rows ``X[idxs]``.

            Splitting on the highest-variance dimension tends to spread the data
            evenly between the two subtrees. Ties resolve to the lowest index.
            """
            n_dims = len(X[0])
            return max(range(n_dims), key=lambda j: self._get_variance(X, idxs, j))

        def _get_variance(self, X, idxs, dimension):
            """Return the population variance of column ``dimension`` over rows ``X[idxs]``.

            Args:
                X: 2-D sample set (sequence of rows).
                idxs: indices of the rows to include.
                dimension: column index to compute the variance of.

            Returns:
                D(X) = E(X^2) - E(X)^2, or 0.0 when ``idxs`` is empty.
            """
            n = len(idxs)
            if n == 0:
                return 0.0
            col_sum = col_sum_sq = 0
            for idx in idxs:
                xi = X[idx][dimension]
                col_sum += xi
                col_sum_sq += xi ** 2
            # Return only after every selected sample has been accumulated.
            return col_sum_sq / n - (col_sum / n) ** 2

        def _get_hyper_plane_dist(self, point, nd):
            """Return the distance from ``point`` to node ``nd``'s splitting hyperplane.

            The hyperplane is ``x[nd.split] == nd.feature[nd.split]``, so the
            distance is the absolute difference along that single dimension.
            """
            j = nd.split
            return abs(point[j] - nd.feature[j])

        def search_node(self, point, node):
            """Descend from ``node`` looking for the node whose feature equals ``point.feature``.

            Args:
                point: a Node-like object; only ``point.feature`` is read.
                node: subtree root to start from.

            Returns:
                The matching node if one is met on the way down; otherwise the
                node where the descent stopped (a leaf, or None when the
                required child is missing).
            """
            while node is not None and (node.left or node.right):
                j = node.split
                if point.feature[j] < node.feature[j]:
                    node = node.left
                elif point.feature[j] > node.feature[j]:
                    node = node.right
                elif node.equals(point):
                    break
                else:
                    # Tie on the split coordinate but a different point: equal
                    # values may sit in either subtree after the median split,
                    # so look left first, then continue right. (Previously this
                    # branch never advanced `node` and looped forever.)
                    found = self.search_node(point, node.left)
                    if found is not None and found.equals(point):
                        return found
                    node = node.right
            return node

        def get_search_path(self, point, node):
            """Return the features visited while descending from ``node`` toward ``point``.

            At each node go left if ``point`` is strictly smaller on the split
            dimension, otherwise right; stop at a leaf or a missing child. The
            recorded path is what backtracking walks over.
            """
            search_path = [node.feature]
            while node is not None and (node.left or node.right):
                j = node.split
                node = node.left if point[j] < node.feature[j] else node.right
                if node is not None:
                    search_path.append(node.feature)
            return search_path

        def get_search_directions(self, point, node):
            """Return the left/right steps taken while descending toward ``point``.

            Each successful step appends 1 for "went left" and 0 for "went
            right"; a step into a missing child is not recorded and ends the walk.
            """
            search_directions = []
            while node is not None and (node.left or node.right):
                j = node.split
                if point[j] < node.feature[j]:
                    node, step = node.left, 1
                else:
                    node, step = node.right, 0
                if node is not None:
                    search_directions.append(step)
            return search_directions

        def _get_median_idx(self, X, idxs, feature):
            """Return the index (from ``idxs``) of the median row of ``X`` along column ``feature``.

            For an even count the upper median (position ``len(idxs) // 2``) is used.
            """
            k = len(idxs) // 2
            # sorted() is stable, so rows with equal values keep their idxs order.
            return sorted(idxs, key=lambda i: X[i][feature])[k]

        def _split_feature(self, X, idxs, feature, median_idx):
            """Partition ``idxs`` around row ``median_idx`` along column ``feature``.

            Returns:
                ``[left, right]`` where ``left`` holds indices strictly below the
                median value and ``right`` the rest; ``median_idx`` itself is
                excluded.
            """
            idxs_split = [[], []]
            split_val = X[median_idx][feature]
            for idx in idxs:
                if idx == median_idx:
                    continue
                if X[idx][feature] < split_val:
                    idxs_split[0].append(idx)
                else:
                    idxs_split[1].append(idx)
            return idxs_split

        def build_tree(self, points, depth):
            """Recursively build a kd-tree and return its root Node (None for no points).

            The split dimension cycles with depth (``depth % k``). The median
            point along it becomes the node; points sorted before it form the
            left subtree and points sorted after it the right subtree.
            """
            if len(points) == 0:
                return None
            cutting_dim = depth % len(points[0])
            median = len(points) // 2
            ordered = sorted(points, key=itemgetter(cutting_dim))
            node = Node(feature=ordered[median], split=cutting_dim)
            node.left = self.build_tree(ordered[:median], depth + 1)
            node.right = self.build_tree(ordered[median + 1:], depth + 1)
            return node

        def nearest_neighbor_search(self, point, tree):
            """Return the feature in ``tree`` closest to ``point`` (Euclidean distance).

            Standard kd-tree search: descend toward the region containing
            ``point``, then while unwinding visit a node's far subtree only when
            the ball around ``point`` whose radius is the best distance so far
            crosses that node's splitting hyperplane.

            Args:
                point: query coordinates (list, tuple or array of length k).
                tree: root Node as returned by ``build_tree``.

            Returns:
                The ``feature`` of the nearest node, or None if ``tree`` is None.
            """
            if tree is None:
                return None
            query = np.asarray(point, dtype=float)
            best_feature = None
            best_dist = float("inf")

            def visit(node):
                nonlocal best_feature, best_dist
                if node is None:
                    return
                dist = self._distance(node.feature, query)
                if dist < best_dist:
                    best_feature, best_dist = node.feature, dist
                j = node.split
                if query[j] < node.feature[j]:
                    near, far = node.left, node.right
                else:
                    near, far = node.right, node.left
                visit(near)
                # The far side can only hold a strictly closer point if the
                # current best ball crosses the splitting hyperplane.
                if self._get_hyper_plane_dist(query, node) < best_dist:
                    visit(far)

            visit(tree)
            return best_feature
    # Demo: build a 2-d tree from six sample points and print the stored point
    # nearest to the query (2, 4.5). By direct distance comparison the true
    # nearest sample point is (2, 3).
    features = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
    tree = KDTree()
    node = tree.build_tree(features, 0)
    print(tree.nearest_neighbor_search([2, 4.5], node))
    

    代码待优化：最近邻搜索在回溯节点时考虑得不够周全。

    参考文档:
    <https://blog.csdn.net/pipisorry/article/details/52186307>

    相关文章

      网友评论

          本文标题:第三章 kd-tree

          本文链接:https://www.haomeiwen.com/subject/dqodictx.html