import numpy as np
from operator import itemgetter
# 获取两个特征点的二范数,也叫欧式距离
def _get_euclidean_distance(feature1,feature2):
return ((feature1-feature2) ** 2).sum() ** 0.5
class Node(object):
#把args打包成tuple使用,
#比如函数类型 function(x, y, *args),参数: 1,2,3,4,5,会变成1,2,(3,4,5)
# **kwargs 打包参数成dict使用。
#比如函数类型 function(**kwargs) function(a=1, b=2, c=3),会变成 {'a':1,'b':2,'c':3}
#写成这样子的好处是一个构造函数,可以当多个构造函数使用
def __init__(self,*args,**kwargs):
self.father=kwargs.get("father")
self.left=kwargs.get("left")
self.right=kwargs.get("right")
#节点的特征
self.feature=kwargs.get("feature")
#节点的分隔维度
self.split=kwargs.get("split")
#判断两个节点是否相等 ,待优化
def equals(self,node):
if(self.feature==node.feature).all():
return True
else:
return False
class KDTree(object):
def __init__(self):
self.root =Node()
#构造kd树的时候,选择方差大的维度,进行分割,使得kd树数据尽量分布均匀
def _choose_feature(self,X,idxs):
m=len(X[0])
variances=map(lambda j:(j,self._get_variance(X,idxs,j)),range(m))
return max(variances,key=lambda x:x[1])[0]
#X是样本集,是个二维数组,idxs是要求方差的样本点的索引,一维数组,dimension是对哪个维度进行
#方差计算
def _get_variance(self,X,idxs,dimension):
n=len(idxs)
col_sum=col_sum_sqrt=0
for idx in idxs:
xi=X[idx][dimension]
col_sum+=xi
col_sum_sqrt+=xi ** 2
## D(X)=E(X^2)-[E(X)]^2
return col_sum_sqrt/n -(col_sum/n) ** 2
#计算输入点point,到nd所在的分离超平面的距离
def _get_hyper_plane_dist(self,point,nd):
#节点的分割维度
j=nd.split
#分割维度上的差到超平面的距离
return abs(point[j]-nd.feature[j])
def search_node(self,point,node):
while node!=None and (node.left or node.right):
if(point.feature[node.split]<node.feature[node.split]):
node=node.left
elif point.feature[node.split]>node.feature[node.split]:
node=node.right
else:
if(node.equals(point)):
break;
return node
#记录节点的查找到叶子节点的搜索路径,以便后面的回溯
def get_search_path(self,point,node):
search_path=[node.feature]
while node!=Node and (node.left or node.right):
if point[node.split]<node.feature[node.split]:
node=node.left
if node!=None:
search_path.append(node.feature)
else:
node=node.right
if node!=None:
search_path.append(node.feature)
return search_path
#记录搜索路径的向左向右的集合,也算是记录路径的一种
def get_search_directions(self,point,node):
search_directions=[]
while node!=Node and (node.left or node.right):
if point[node.split]<node.feature[node.split]:
node=node.left
if node!=None:
search_directions.append(1)
else:
node=node.right
if node!=None:
search_directions.append(0)
return search_directions
#获取样本集X,二维数组,在feature维度,的中值的索引
def _get_median_idx(self,X,idxs,feature):
n=len(idxs)
k=n//2
col=map(lambda i:(i,X[i][feature]),idxs)
sorted_idxs=map(lambda x:x[0],sorted(col,key=lambda x:x[1]))
median_idx=list(sorted_idxs)[k]
return median_idx
def _split_feature(self,X,idxs,feature,median_idx):
idxs_split=[[],[]]
split_val=X[median_idx][feature]
for idx in idxs:
if idx== median_idx:
continue
xi=X[idx][feature]
if xi<split_val:
idxs_split[0].append(idx)
else:
idxs_split[1].append(idx)
return idxs_split
#递归构建kd树
def build_tree(self,points,depth):
if 0 == len(points):
return None
cutting_dim= depth % len(points[0])
medium_index=int(len(points)/2)
pointsNew=sorted(points,key=itemgetter(cutting_dim))
node=Node(feature=pointsNew[medium_index],split=cutting_dim)
node.left=self.build_tree(pointsNew[:medium_index],depth+1)
node.right=self.build_tree(pointsNew[medium_index+1:],depth+1)
return node
#最近邻查找
def nearest_neighbor_search(self,point,tree):
search_path=self.get_search_path(point,tree)
search_directions = self.get_search_directions(point,tree)
#当前最近节点
node_best=search_path.pop()
#回溯节点
back_node=search_path.pop()
#回溯节点方向(最后一步的方向)
back_direction=search_directions.pop()
#一直循环,直到回溯节点为空
while all(back_node)!=None and len(search_path):
#当前里输入点,最近距离
dis_best=_get_euclidean_distance(node_best,point)
#输入点,到超平面的距离
dis_hyper_plane=self._get_hyper_plane_dist(point,Node(feature=back_node,split=back_direction))
if dis_best>= dis_hyper_plane:
node_best=back_node
dis_best=_get_euclidean_distance(node_best,point)
#找到回溯节点
searchedBackNode=self.search_node(Node(feature=back_node,split=back_direction),tree)
if(back_direction==0):
#dis_best>dis_hyper_plane就是输入点到最近点的距离,画圆,和超平面有交点。
#说明另外一个子空间,可能有比当前最近点,更近的。 需要回溯到超平面所在的分隔
#所在的另外一个子空间,去寻找更近点
if searchedBackNode.left:
search_path.append(searchedBackNode.left.feature)
else:
if searchedBackNode.right:
search_path.append(searchedBackNode.right.feature)
if len(search_path)>0:
back_node= search_path.pop()
if len(search_directions)>0:
back_direction=search_directions.pop()
return node_best
tree=KDTree()
features=np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
node=tree.build_tree(features,0)
##print(tree.nearest_neighbor_search([2.1,3.1],node))
print(tree.nearest_neighbor_search([2,4.5],node))
代码待优化,回溯节点的时候,没有考虑周全。
参考文档:
<https://blog.csdn.net/pipisorry/article/details/52186307>
网友评论