GitHub - LittleFox99/B_A_star: Bidirectional A Star
其中a,b分别为双向A*搜索的权重
#-*- coding:utf-8 -*-
# @Time : 2020/11/11 1:21 下午
# @Author : LittleFox99
# File : a_star.py
# 参考:
# https://blog.csdn.net/lustyoung/article/details/105027607
# https://www.jianshu.com/p/5704e67f40aa
import numpy as np
import queue
from queue import PriorityQueue
from pylab import *
import matplotlib.pyplot as pltclass A_star:def __init__(self, a, b, start, end, max_row, max_col, barrier_map):"""A_star 初始化:param a::param b::param start::param end::param max_row::param max_col::param barrier_map:"""# 权重参数self.a = aself.b = b# 起始点、终点self.start = startself.end = end# 相遇点self.stop_point = Noneself.stop_point_back = None#采用优先队列实现的小顶堆,用于存放待扩展结点,同时利用f值作为排序指标;self.opened_list1 = PriorityQueue()self.opened_list2 = PriorityQueue()self.open_all_list = []#储存记录所有走过的openlistself.open_list = set()#采用set(红黑树)实现,便于快速查找当前point是否存在于closed_list中;self.closed_list1 = set()self.closed_list2 = set()# 地图边界self.max_row = max_rowself.max_col = max_col# 论文中设定为四个移动方向self.direction = [[0, 1], [0, -1], [1, 0], [-1, 0]]# 障碍地图self.barrier_map = barrier_map#遍历结点数self.travel_num = 0def bound(self, row, col):"""边界条件检测:1.如果该点坐标超出地图区域,则非法2.如果该点是障碍,则非法:param row::param col::return:"""if row < 0 or row > self.max_row:return Trueif col < 0 or col > self.max_col:return Trueif [row, col] in self.barrier_map:return Truereturn Falsedef euclidean_distance(self, point1, point2):"""欧式距离计算:param point1::param point2::return:"""x_ = abs(point1.getx() - point2.getx())y_ = abs(point1.gety() - point2.gety())return ((x_**2)+(y_**2))**(0.5)*1.0def find_path(self, point):"""用栈回溯查找B-A*给出的路径:param point::return:"""stack = []father_point = pointwhile father_point is not None:stack.append(father_point.get_coordinate())father_point = father_point.fatherreturn stackdef heuristic_h(self, end_point, point1, point2):"""启发函数h值计算h:正向搜索当前点n1到终点的欧式距离h_:正向搜索当前点n1到反向搜索当前点n2的距离:param end_point::param point1::param point2::return:"""h = self.euclidean_distance(point1, end_point)h_ = self.euclidean_distance(point1, point2)return (1 - self.b) * (h*1.0 / (1 + self.a) + self.a * h_*1.0 / (1 + self.a))def heuristic_g(self, point):"""启发函数g值计算g:当前点父亲的g值g_:当前点到其父亲的欧式距离:param point::return:"""g = point.father.gg_ = self.euclidean_distance(point, point.father)return self.b * (g + g_)*1.0def heuristic_f(self, h, g):# 计算启发式函数f值return h + gdef compute_child_node(self, current_point, back_current_point, search_forward):"""遍历当前点的子节点:param current_point: 当前点:param back_current_point: 相反方向搜索的当前点:param search_forward: 搜索方向设置,True为正向,反之亦然:return:"""# 四个方向的遍历for direct in self.direction:col_index = current_point.gety() + direct[0]row_index = current_point.getx() + direct[1]# 创建子节点,将当前点设置为子节点的父节点child_point = Point(row_index, col_index, current_point)# 查找open_list中是否存在子节点中,用于备份节点old_child_point = None#前向搜索if search_forward == True:#计算子节点各个启发函数值child_point.h = self.heuristic_h(self.end, current_point, back_current_point)child_point.g = self.heuristic_g(child_point)child_point.f = self.heuristic_f(child_point.h, child_point.g)# 边界检测, 如果它不可通过或者已经在关闭列表中, 跳过if (row_index, col_index) in self.closed_list1 or self.bound(row_index, col_index)==True:continueelse:self.travel_num = self.travel_num + 1# 通过检测则将子节点加入openlist(记录所有加入过openlist的点)self.open_list.add(child_point.get_coordinate())# 找到最短路径if (row_index, col_index) in self.closed_list2:self.stop_point_back = self.search_back_point(child_point, self.open_all_list)self.stop_point = child_pointif self.stop_point_back is None:self.stop_point_back = self.stop_pointprint("forward!")return True"""如果可到达的节点存在于OPEN1列表中,称该节点为x点,计算经过n1点到达x点的g1(x)值,如果该g1(x)值小于原g1(x)值,则将n1点作为x点的父节点,更新OPEN1列表中的f1(x)和g1(x)。"""tmp_queue = [] # opened_list的备份while not self.opened_list1.empty():tmp_point = self.opened_list1.get()if child_point.get_coordinate() == tmp_point.get_coordinate(): #找到x点,跳过old_child_point = tmp_pointcontinuetmp_queue.append(tmp_point)while len(tmp_queue) != 0:self.opened_list1.put(tmp_queue.pop())if old_child_point is None: #如果没找到,直接加入子节点self.opened_list1.put(child_point)else:# 找到x点# 用g值为参考检查新的路径是否更好。更低的g值意味着更好的路径。if old_child_point.g > child_point.g:self.opened_list1.put(child_point)else:# print(2)self.opened_list1.put(old_child_point)# print(2)# 反向搜索,同理如上else:# 边界检测, 如果它不可通过或者已经在关闭列表中, 跳过child_point.h = self.heuristic_h(self.start, current_point, back_current_point)child_point.g = self.heuristic_g(child_point)child_point.f = self.heuristic_f(child_point.h, child_point.g)if (row_index, col_index) in self.closed_list2 or self.bound(row_index, col_index):continueelse:self.travel_num = self.travel_num + 1self.open_list.add(child_point.get_coordinate())# 找到最短路径if (row_index, col_index) in self.closed_list1:# self.stop_point_back = self.search_father_point(child_point,self.opened_list2)self.stop_point_back = self.search_back_point(child_point, self.open_all_list)self.stop_point = child_pointif self.stop_point_back is None:self.stop_point_back = self.stop_pointprint("backward!")return Truetmp_queue = []while not self.opened_list2.empty():tmp_point = self.opened_list2.get()if child_point.get_coordinate() == tmp_point.get_coordinate():old_child_point = tmp_pointcontinuetmp_queue.append(tmp_point)while len(tmp_queue) != 0:self.opened_list2.put(tmp_queue.pop())if old_child_point is None: #open_list没有找到子节点,则将self.opened_list2.put(child_point)else:# 用g值为参考检查新的路径是否更好。更低的G值意味着更好的路径。if old_child_point.g > child_point.g:self.opened_list2.put(child_point)else:self.opened_list2.put(old_child_point)# print(3)return Falsedef search_back_point(self, point, opened_list):back_point = Nonewhile len(opened_list)!=0:tmp_point = opened_list.pop()if point.get_coordinate() == tmp_point.get_coordinate():return tmp_pointreturn back_pointdef search(self):# 将起始点s设置为正向当前结点n1、终点e设置为反向当前结点n2current_point1 = self.startcurrent_point2 = self.end# 并加入open1、open2self.opened_list1.put(current_point1)self.opened_list2.put(current_point2)forward_path, backward_path =None, None# opened_list1与opened_list2全部非空,输出寻路提示失败find_stop = Falsemin_f_point1 = self.opened_list1.get()self.closed_list1.add((min_f_point1.getx(), min_f_point1.gety()))min_f_point2 = self.opened_list2.get()self.closed_list2.add((min_f_point2.getx(), min_f_point2.gety()))while True:# # 取出open_list1中f值最小的点,加入closed_list1# 将其作为当前结点,遍历寻找它的子节点# min_f_point1 = self.opened_list1.get()# self.closed_list1.add((min_f_point1.getx(), min_f_point1.gety()))self.open_all_list.append(current_point1)self.open_all_list.append(current_point2)find_stop = self.compute_child_node(current_point1, current_point2, True)if find_stop:# forward_path = self.find_path(current_point1)forward_path = self.find_path(self.stop_point)backward_path = self.find_path(self.stop_point_back)break# min_f_point1 = self.opened_list1.get()# print(1)min_f_point1 = self.opened_list1.get()# self.open_all_list.append(min_f_point1)self.closed_list1.add((min_f_point1.getx(), min_f_point1.gety()))current_point1 = min_f_point1self.open_all_list.append(current_point1)# 取出open_list1中f值最小的点,加入closed_list1# current_point1 = self.opened_list1.get()find_stop = self.compute_child_node(current_point2, current_point1, False)if find_stop:forward_path = self.find_path(self.stop_point)backward_path = self.find_path(self.stop_point_back)# backward_path = self.find_path(self.stop_point)break# min_f_point2 = self.opened_list2.get()min_f_point2 = self.opened_list2.get()# self.open_all_list.append(min_f_point1)self.closed_list2.add((min_f_point2.getx(), min_f_point2.gety()))current_point2 = min_f_point2if self.opened_list1.qsize() == 0 or self.opened_list2.qsize() == 0:break# current_point2 = self.opened_list2.get()if backward_path==None and forward_path==None:print("Fail to find the path!")return Noneelse:forward_path = forward_path + backward_pathreturn forward_path, self.open_list, self.stop_point, self.travel_numclass Point:"""Point——地图上的格子,或者理解为点1.坐标2.g,h,f,father"""def __init__(self, x, y, father):self.x = xself.y = yself.g = 0self.h = 0self.f = 0self.father = father# 用于优先队列中f值的排序def __lt__(self, other):return self.f < other.f# 获取x坐标def getx(self):return self.x# 获取y坐标def gety(self):return self.y# 获取x, y坐标def get_coordinate(self):return self.x, self.ydef draw_path(max_row, max_col, barrier_map, path=None, openlist=None, stop_point=None, start=None ,end=None):"""画出B-A*算法的结果模拟地图:param max_row:地图x最大距离:param max_col:地图y最大距离:param barrier_map:障碍地图:param path:B-A*给出的较优路径:param openlist:B-A*中的openlist:param stop_point:B-A*中的相遇结点:param start:起始点:param end:终点:return:""""""划分数组的x,y坐标"""barrier_map_x = [i[0] for i in barrier_map]barrier_map_y = [i[1] for i in barrier_map]path_x = [i[0] for i in path]path_y = [i[1] for i in path]open_list_x = [i[0] for i in openlist]open_list_y = [i[1] for i in openlist]"""对画布进行属性设置"""# plt.subplot(2, 2, subplot)plt.figure(figsize=(15, 15)) # 为了防止x,y轴间隔不一样长,影响最后的表现效果,所以手动设定等长plt.xlim(-1, max_row)plt.ylim(-1, max_col)my_x_ticks = np.arange(0, max_row, 1)my_y_ticks = np.arange(0, max_col, 1)plt.xticks(my_x_ticks) # 竖线的位置与间隔plt.yticks(my_y_ticks)plt.grid(True) # 开启栅格"""画图"""plt.scatter(barrier_map_x, barrier_map_y, s=500, c='k', marker='s')plt.scatter(open_list_x, open_list_y, s=500, c='cyan', marker='s')plt.scatter(path_x, path_y,s=500, c='r', marker='s')plt.scatter(stop_point.getx(), stop_point.gety(), s=500, c='g', marker='s')plt.scatter(start.getx(),start.gety(),s=500, c='b', marker='s')plt.scatter(end.getx(), end.gety(), s=500, c='b', marker='s')plt.title("Bidirectiional A Star , a = {}, b = {}".format(a, b))print(path)plt.savefig("result_pic/a_{},b_{}.png".format(a, b))plt.show()def draw_arg(travel_num, path_num, a_list, b_list):markes = ['-o', '-s', '-^', '-p', '-^','-v']travel_num = np.array(travel_num)path_num = np.array(path_num)fig,ax = plt.subplots(2, 1, figsize=(15,12))for i in range(0, 6):ax[0].plot(a_list, travel_num[i*6:(i+1)*6], markes[i],label=b_list[i])ax[0].set_ylabel('Travel num')ax[1].plot(a_list, path_num[i*6:(i+1)*6], markes[i],label=b_list[i])ax[1].set_xlabel('The values of a')ax[1].set_ylabel('Path nun')ax[0].set_title('the two args of B-A*')box = ax[1].get_position()ax[1].set_position([box.x0, box.y0, box.width, box.height])ax[1].legend(loc='right', bbox_to_anchor=(1.1, 0.6), ncol=1, title="b")plt.savefig("ab.png")plt.show()if __name__ == '__main__':"""权重的值"""a_list = [2, 3, 4, 5, 6, 7]b_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]"""地图、障碍物的位置、起始点、终点"""max_row = 30max_col = 30# barrier_map = [[1, 2], [3, 5], [4,7],[6, 6], [11, 19], [11, 6],[11, 5],[12,5],[12,6],[11, 16], [11, 17], [11, 18], [7, 7]]barrier_map = [[6, 29], [7, 29], [8, 29], [6, 28], [7, 28], [8, 28], [6, 27], [7, 27], [8, 27], [6, 26], [7, 26],[8, 26], [6, 25], [7, 25], [8, 25], [6, 24], [7, 24], [8, 24], [29, 25], [28, 25], [27, 25],[26, 25], [25, 25], [24, 25], [23, 25], [22, 25], [21, 25], [29, 24], [28, 24], [27, 24], [26, 24],[24, 24], [25, 24],[23, 24], [22, 24], [21, 24], [20, 10], [21, 10], [22, 10], [23, 10], [24, 10], [20, 9], [21, 9],[22, 9], [23, 9], [24, 9], [20, 8], [21, 8], [22, 8], [23, 8], [24, 8], [20, 7], [21, 7], [22, 7],[23, 7], [24, 7], [20, 6], [21, 6], [22, 6], [23, 6], [24, 6], [20, 5], [21, 5], [22, 5], [23, 5],[24, 5], [20, 4], [21, 4], [22, 4], [23, 4], [24, 4], [20, 3], [21, 3], [22, 3], [23, 3], [24, 3],[20, 2], [21, 2], [22, 2], [23, 2], [24, 2], [20, 1], [21, 1], [22, 1], [23, 1], [24, 1], [20, 0],[21, 0], [22, 0], [23, 0], [24, 0],[16, 16], [16, 17], [16, 18], [17, 16], [17, 17], [17, 18], [18, 16], [18, 17], [18, 18], [19, 16],[19, 17], [19, 18], [20, 16], [20, 17], [20, 18], [21, 16], [21, 17], [21, 18], [22, 16], [22, 17],[22, 18], [23, 16], [23, 17], [23, 18], [24, 16], [24, 17], [24, 18], [25, 16], [25, 17], [25, 18],[26, 16], [26, 17], [26, 18], [27, 16], [27, 17], [27, 18], [28, 16], [28, 17], [28, 18], [29, 16],[29, 17], [29, 18]]# barrier_map = np.array(barrier_map)# start = Point(4, 5, None)# end = Point(18, 8, None)start = Point(27, 2, None)end = Point(0, 29, None)travel_point_num = []path_point_num = []"""遍历权重的值,使用B-A*算法"""# for a, b in zip(a_list, b_list):for b in b_list:for a in a_list:print(a,b)path, open_list, stop_point, tp_num = A_star(a, b, start, end, max_row, max_col,barrier_map).search()if path is not None:# 找到路径print("Find the path!")travel_point_num.append(tp_num)path_point_num.append(len(path))draw_path(max_row, max_col, barrier_map, path, open_list, stop_point, start, end)else:print("Fail to find the path!")print(travel_point_num)print(path_point_num)draw_arg(travel_point_num, path_point_num, a_list, b_list)