Goal
This article walks through a simple example. The goal is shown below:
![](https://img.haomeiwen.com/i13326502/fb5eed317326d76f.png)
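Our robot moves around the map above and tries to find the treasure. If it steps into a skull cell, the game is lost.
Defining the state space
We number the cells on the map from 1 to 8: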
![](https://img.haomeiwen.com/i13326502/c872e5cb3dea46a1.png)
In code, this can be written simply as:
self.states = [1, 2, 3, 4, 5, 6, 7, 8]  # state space
We also define the terminal states:
self.terminate_states = dict()  # terminal states, stored as a dict
self.terminate_states[6] = 1
self.terminate_states[7] = 1
self.terminate_states[8] = 1
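As a minimal sketch, the state definitions above could be wrapped in a small class with a helper that checks whether an episode has ended. The class name GridStates and the method is_terminal are hypothetical names introduced here for illustration; they are not part of the original code.

```python
class GridStates:
    """Hypothetical container for the state definitions above."""

    def __init__(self):
        self.states = [1, 2, 3, 4, 5, 6, 7, 8]  # state space
        self.terminate_states = dict()          # terminal states, stored as a dict
        self.terminate_states[6] = 1
        self.terminate_states[7] = 1
        self.terminate_states[8] = 1

    def is_terminal(self, state):
        # Membership test on the dict keys; the stored value 1 is only a flag.
        return state in self.terminate_states


grid = GridStates()
non_terminal = [s for s in grid.states if not grid.is_terminal(s)]
print(non_terminal)  # [1, 2, 3, 4, 5]
```

Because only the keys are used, a set would work just as well; the dict is kept here to match the original snippet.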
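Defining the action space
The action space is simple: up, right, down, and left, encoded as the compass letters n, e, s, w (north, east, south, west):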
self.actions = ['n','e','s','w']
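A short sketch of how these action codes might be used: the mapping to readable names and the sample_action helper are hypothetical additions for illustration, assuming that n, e, s, w correspond to up, right, down, left on the map.

```python
import random

ACTIONS = ['n', 'e', 's', 'w']

# ASSUMPTION: compass letters map to screen directions like this.
ACTION_NAMES = {'n': 'up', 'e': 'right', 's': 'down', 'w': 'left'}


def sample_action(rng=random):
    """Pick one action uniformly at random (hypothetical helper)."""
    return rng.choice(ACTIONS)


action = sample_action()
print(action, ACTION_NAMES[action])
```

Uniform random sampling like this is a common way to explore an environment before any policy has been learned.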
Defining the reward function
self.rewards = dict()  # rewards are stored in a dict keyed by "state_action"
self.rewards['1_s'] = -1.0
self.rewards['3_s'] = 1.0
self.rewards['5_s'] = -1.0
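The reward dict above is keyed by a "state_action" string, so looking up a reward means building that key. The sketch below shows one way a step function could use it. Note the assumptions: the TRANSITIONS table is not given in the text and is inferred from the map layout (states 1 to 5 in the top row, states 6, 7, 8 below cells 1, 3, 5); pairs missing from the reward dict are assumed to give 0; and moves into a wall are assumed to leave the robot in place. The function name step is illustrative, not the author's implementation.

```python
REWARDS = {'1_s': -1.0, '3_s': 1.0, '5_s': -1.0}
TERMINATE_STATES = {6: 1, 7: 1, 8: 1}

# ASSUMPTION: transition table inferred from the map; not part of the original text.
TRANSITIONS = {
    '1_s': 6, '1_e': 2,
    '2_w': 1, '2_e': 3,
    '3_s': 7, '3_w': 2, '3_e': 4,
    '4_w': 3, '4_e': 5,
    '5_s': 8, '5_w': 4,
}


def step(state, action):
    """Return (next_state, reward, done) for one move."""
    if state in TERMINATE_STATES:
        return state, 0.0, True
    key = "%d_%s" % (state, action)
    next_state = TRANSITIONS.get(key, state)  # ASSUMPTION: blocked moves keep the state
    reward = REWARDS.get(key, 0.0)            # ASSUMPTION: unlisted pairs give 0
    done = next_state in TERMINATE_STATES
    return next_state, reward, done


print(step(3, 's'))  # (7, 1.0, True)
print(step(4, 'n'))  # (4, 0.0, False)
```

With this layout, only the three downward moves from cells 1, 3, and 5 end the episode, which matches the three entries in the reward dict.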
Next we build the environment. The target rendering looks like this:
![](https://img.haomeiwen.com/i13326502/0736333928e5a901.png)
(If you do not know how to draw with gym, see: https://www.jianshu.com/p/b3c4d2b95c58)
The code is as follows:
def render(self, mode='human', close=False):
    # Assumes `from gym.envs.classic_control import rendering` (older gym releases)
    # and that self.viewer has already been created as a rendering.Viewer elsewhere.
    # Build the map grid
    lines = []
    line1 = rendering.Line((0, 0), (0, 200))
    lines.append(line1)
    line2 = rendering.Line((120, 0), (120, 200))
    lines.append(line2)
    line3 = rendering.Line((0, 200), (600, 200))
    lines.append(line3)
    line4 = rendering.Line((0, 100), (600, 100))
    lines.append(line4)
    line5 = rendering.Line((240, 0), (240, 200))
    lines.append(line5)
    line6 = rendering.Line((360, 0), (360, 200))
    lines.append(line6)
    line7 = rendering.Line((480, 0), (480, 200))
    lines.append(line7)
    line8 = rendering.Line((600, 0), (600, 200))
    lines.append(line8)
    line9 = rendering.Line((0, 0), (120, 0))
    lines.append(line9)
    line10 = rendering.Line((240, 0), (360, 0))
    lines.append(line10)
    line11 = rendering.Line((480, 0), (600, 0))
    lines.append(line11)
    # Create the skulls (black circles)
    kulos = []
    kulo1 = rendering.make_circle(40)
    kulo1_transiform = rendering.Transform(translation=(60, 50))
    kulo1.add_attr(kulo1_transiform)
    kulo1.set_color(0, 0, 0)
    kulos.append(kulo1)
    kulo2 = rendering.make_circle(40)
    kulo2_transiform = rendering.Transform(translation=(540, 50))
    kulo2.add_attr(kulo2_transiform)
    kulo2.set_color(0, 0, 0)
    kulos.append(kulo2)
    # Create the treasure (gold circle)
    golds = []
    gold = rendering.make_circle(40)
    circletrans = rendering.Transform(translation=(300, 50))
    gold.add_attr(circletrans)
    gold.set_color(1, 0.9, 0)
    golds.append(gold)
    # Create the robot
    robots = []
    robot = rendering.make_circle(40)
    robot.set_color(0.8, 0.6, 0.4)
    robot_transilation = rendering.Transform(translation=(420, 150))
    robot.add_attr(robot_transilation)
    robots.append(robot)
    # Shift every shape by (50, 50) so the map has a margin inside the window
    transiform = rendering.Transform(translation=(50, 50))
    for line_ in lines:
        line_.set_color(0, 0, 0)
        line_.add_attr(transiform)
        self.viewer.add_geom(line_)
    for kulo_ in kulos:
        kulo_.add_attr(transiform)
        self.viewer.add_geom(kulo_)
    for gold_ in golds:
        gold_.add_attr(transiform)
        self.viewer.add_geom(gold_)
    for robot_ in robots:
        robot_.add_attr(transiform)
        self.viewer.add_geom(robot_)
    return self.viewer.render(return_rgb_array=mode == 'rgb_array')
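The render code places the robot at a fixed position, (420, 150). To move it as the state changes, the state number has to be converted into pixel coordinates. The sketch below shows one possible conversion, derived from the grid geometry in the code above (cells 120 wide, row centres at y = 150 and y = 50). The function name state_to_pixel and the assumption that states 6, 7, 8 lie below columns 1, 3, 5 are illustrative and not taken from the original code.

```python
CELL_WIDTH = 120   # horizontal spacing of the vertical grid lines
TOP_ROW_Y = 150    # centre of the top row (y between 100 and 200)
BOTTOM_ROW_Y = 50  # centre of the bottom row (y between 0 and 100)

# ASSUMPTION: states 6, 7, 8 sit below columns 1, 3, 5 of the top row.
BOTTOM_COLUMNS = {6: 1, 7: 3, 8: 5}


def state_to_pixel(state):
    """Map a state number to the (x, y) centre of its cell, before the (50, 50) margin shift."""
    if 1 <= state <= 5:
        column, y = state, TOP_ROW_Y
    elif state in BOTTOM_COLUMNS:
        column, y = BOTTOM_COLUMNS[state], BOTTOM_ROW_Y
    else:
        raise ValueError("unknown state: %r" % (state,))
    x = CELL_WIDTH * (column - 1) + CELL_WIDTH // 2
    return x, y


print(state_to_pixel(4))  # (420, 150), the robot's position in the render code
print(state_to_pixel(7))  # (300, 50), the treasure's position
```

If the robot's Transform were kept as an attribute rather than a local variable, the result could be passed to its set_translation method on each render call. This is a suggestion under those assumptions, not something the original code does.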
Result
This example uses the DQN (Deep Q-Network) algorithm. Readers do not need to know what DQN is for now.
![](https://img.haomeiwen.com/i13326502/e1ab0dd332e3f884.gif)