美文网首页
堆优化的A*算法-Python实现

堆优化的A*算法-Python实现

作者: 微雨旧时歌丶 | 来源:发表于2020-03-27 10:43 被阅读0次

    堆优化的A*算法-Python实现

    A*算法解决二维网格地图中的寻路问题

    • 输入:图片(白色区域代表可行,深色区域代表不行可行)
    • 输出:路径(在图中绘制)
    """ 方格地图中的A*算法 (openList进行了堆优化)
    A* 算法:  F = G+H
    F: 总移动代价
    G: 起点到当前点的移动代价  直:1, 斜:1.4
    H: 当前点到终点的预估代价  曼哈顿距离
    ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    1.把起点加入 openList中
    2.While True:
        a.遍历openList,查找F值最小的节点,作为current
        b.current是终点:
            ========结束========
        c.从openList中弹出,放入closeList中
        d.对八个方位的格点:
            if 越界 or 是障碍物 or 在closeList中:
                continue
            if 不在openList中:
                设置父节点,F,G,H
                加入openList中
            else:
                if 这条路径更好:
                    设置父节点,F,G
                    更新openList中的对应节点
    3.生成路径path
    +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    堆优化:
        openList:作为最小堆,按F值排序存储坐标 (不更新只增加)
        openDict:坐标:点详细信息 (既更新又增加)
        get_minfNode() 从openList中弹出坐标,去openDict中取点 (但由于不更新只增加,坐标可能冗余)
        in_openList() 判断坐标是否在openDict中即可 
    
    """
    import math
    from PIL import Image,ImageDraw 
    import numpy as np
    import heapq # 堆
    
    STAT_OBSTACLE='#'
    STAT_NORMAL='.'
    
    def manhattan(x1,y1, x2,y2):
        """两个Point的曼哈顿距离"""
        h = abs(x1-x2)+abs(y1-y2)
        return h
    
    class Node():
        """
        开放列表和关闭列表的元素类型,parent用来在成功的时候回溯路径
        """
        def __init__(self, x, y,parent=None, g=0, h=0):
            self.parent = parent
            self.x = x
            self.y = y
            self.g = g
            self.h = h
            self.update()
        
        def update(self):
            self.f = self.g+self.h
    
    
    class A_Star:
        """ x是行索引,y是列索引
        """
        def __init__(self, test_map, start=None, end=None):
            """地图,起点,终点"""
            self.map = test_map
            self.cols = len(test_map[0])
            self.rows = len(test_map)
            self.s_x, self.s_y = start if start else [0,0]
            self.e_x, self.e_y = end if end else [self.rows-1,self.cols-1]
            self.closeList = set()
            self.path = []
            self.openList = []  # 堆,只添加,和弹出最小值点,
            self.openDict = dict() # openList中的 坐标:详细信息 -->不冗余的
            
        
        def find_path(self):
            """A*算法寻路主程序"""
            p = Node(self.s_x, self.s_y, 
                     h=manhattan(self.s_x,self.s_y, self.e_x,self.e_y)) # 构建开始节点
            heapq.heappush(self.openList, (p.f,(p.x,p.y)))
            
            self.openDict[(p.x,p.y)] = p  # 加进dict目录
            while True:
                current = self.get_minfNode()
                if current.x==self.e_x and current.y==self.e_y:
                    print('find path')
                    self.make_path(current)
                    break
                
                self.closeList.add((current.x,current.y))  ## 加入closeList
                del self.openDict[(current.x,current.y)]
                self.extend_surrounds(current) # 会更新close list
    
        def make_path(self,p):
            """从结束点回溯到开始点,开始点的parent==None"""
            while p:
                self.path.append((p.x, p.y))
                p = p.parent
        
        def extend_surrounds(self, node):
            """ 将当前点周围可走的点加到openList中,
                其中 不在openList中的点 设置parent、F,G,H 加进去,
                     在openList中的点  更新parent、F,G,H
            """
            motion_direction = [[1, 0], [0,  1], [-1, 0], [0,  -1], 
                                [1, 1], [1, -1], [-1, 1], [-1, -1]]  
            for dx, dy in motion_direction:
                x,y = node.x+dx, node.y+dy
                new_node = Node(x,y)
                # 位置无效,或者是障碍物, 或者已经在closeList中 
                if not self.is_valid_xy(x,y) or not self.not_obstacle(x,y) or self.in_closeList(new_node): 
                    continue
                if abs(dx)+abs(dy)==2:  ## 斜向
                    h_x,h_y = node.x+dx,node.y # 水平向
                    v_x,v_y = node.x,node.y+dy # 垂直向
                    if not self.is_valid_xy(h_x,h_y) or not self.not_obstacle(h_x,h_y) or self.in_closeList(Node(h_x,h_y)): 
                        continue
                    if not self.is_valid_xy(v_x,v_y) or not self.not_obstacle(v_x,v_y) or self.in_closeList(Node(v_x,v_y)): 
                        continue
                #============ ** 关键 **             ========================
                #============ 不在openList中,加进去; ========================
                #============ 在openList中,更新      ========================
                #============对于openList和openDict来说,操作都一样 ===========
                new_g = node.g + self.cal_deltaG(node.x,node.y, x,y)
                sign=False # 是否执行操作的标志 
                if not self.in_openList(new_node): # 不在openList中
                    # 加进来,设置 父节点, F, G, H
                    new_node.h = self.cal_H(new_node)
                    sign=True
                elif self.openDict[(new_node.x,new_node.y)].g > new_g: # 已在openList中,但现在的路径更好
                    sign=True
                if sign:
                    new_node.parent = node
                    new_node.g = new_g
                    new_node.f = self.cal_F(new_node)
                    self.openDict[(new_node.x,new_node.y)]=new_node # 更新dict目录
                    heapq.heappush(self.openList, (new_node.f,(new_node.x,new_node.y)))
            
        def get_minfNode(self):
            """从openList中取F=G+H值最小的 (堆-O(1))"""
            while True:
                f, best_xy=heapq.heappop(self.openList)
                if best_xy in self.openDict:
                    return self.openDict[best_xy]
    
        def in_closeList(self, node):
            """判断是否在closeList中 (集合-O(1)) """
            return True if (node.x,node.y) in self.closeList else False
         
        def in_openList(self, node):
            """判断是否在openList中 (字典-O(1))"""
            if not (node.x,node.y) in self.openDict:
                return False
            else:
                return True
    
        def is_valid_xy(self, x,y):
            if x < 0 or x >= self.rows or y < 0 or y >= self.cols:
                return False
            return True
            
        def not_obstacle(self,x,y):
            return self.map[x][y] != STAT_OBSTACLE
        
        def cal_deltaG(self,x1,y1,x2,y2):
            """ 计算两点之间行走的代价
                (为简化计算)上下左右直走,代价为1.0,斜走,代价为1.4  G值
            """
            if x1 == x2 or y1 == y2:
                return 1.0
            return 1.4
        
        def cal_H(self, node):
            """ 曼哈顿距离 估计距离目标点的距离"""
            return abs(node.x-self.e_x)+abs(node.y-self.e_y) # 剩余路径的估计长度
        
        def cal_F(self, node):
            """ 计算F值 F = G+H 
                A*算法的精髓:已经消耗的代价G,和预估将要消耗的代价H
            """
            return node.g + node.h
    
    
    def plot(test_map,path):
        """绘制地图和路径
            test_map:二维数组
            path:路径坐标数组
        """
        out = []
        for x in range(len(test_map)):
            temp = []
            for y in range(len(test_map[0])):
                if test_map[x][y]==STAT_OBSTACLE:
                    temp.append(0)
                elif test_map[x][y]==STAT_NORMAL:
                    temp.append(255)
                elif test_map[x][y]=='*':
                    temp.append(127)
                else:
                    temp.append(255)
            out.append(temp)
        for x,y in path:
            out[x][y] = 127
        out = np.array(out)
        img = Image.fromarray(out)
        img.show()
    
    def path_length(path):
        """计算路径长度"""
        l = 0
        for i in range(len(path)-1):
            x1,y1 = path[i]
            x2,y2 = path[i+1]
            if x1 == x2 or y1 == y2:
                l+=1.0
            else:
                l+=1.4
        return l
        
    
    def img_to_map(img_file):
        """地图图片变二维数组"""
        test_map = []
        img = Image.open(img_file)
        img = img.resize((100,100))  ### resize图片尺寸
        img_gray = img.convert('L')  # 地图灰度化
        img_arr = np.array(img_gray)
        img_binary = np.where(img_arr<127,0,255)
        for x in range(img_binary.shape[0]):
            temp_row = []
            for y in range(img_binary.shape[1]):
                status = STAT_OBSTACLE if img_binary[x,y]==0 else STAT_NORMAL 
                temp_row.append(status)
            test_map.append(temp_row)
        
        return test_map
    
    # ===== test case ===============
    test_map=img_to_map('map_2.bmp')
    a = A_Star(test_map)
    a.find_path()
    plot(test_map,a.path)
    print('path length:',path_length(a.path))
    

    测试用例及结果

    map1.png map2.png map5.png

    存在的问题

    不确定是否是最优路径
    原文描述:
    “ If we overestimate this distance, however, it is not guaranteed to give us the shortest path. In such cases, we have what is called an "inadmissible heuristic.".

    Technically, in this example, the Manhattan method is inadmissible because it slightly overestimates the remaining distance.”
    即如果我们高估了H,则不能保证最短路径。而曼哈顿距离略微高估了。

    另外,笔者不确定程序是不是正确,以及是不是真正的A*算法,请大神们指正。

    相关文章

      网友评论

          本文标题:堆优化的A*算法-Python实现

          本文链接:https://www.haomeiwen.com/subject/odtluhtx.html