强化学习基础篇(十二)策略评估算法在FrozenLake中的实现
本节将主要基于gym环境中的FrozenLake-v0进行策略评估算法的实现。
1. 迭代策略评估算法的伪代码
迭代策略评估算法,用于估计
输入待评估的策略
算法参数:小阈值,用于确定估计量的精度
对于任意,任意初始化,其中
循环:
对每一个循环:
直到
2. FrozenLake-v0环境
FrozenLake环境是一个GridWorld环境,名字是指在一块冰面上有四种state:
S: initial stat 起点
F: frozen lake 冰湖
H: hole 窟窿
G: the goal 目的地
智能体要学会从起点走到目的地,并且不要掉进窟窿。
FrozenLake-v0.gif首先我们调用 FrozenLake-v0环境:
# 导入库信息
import numpy as np
import gym
# 调用环境
env=gym.make("FrozenLake-v0")
环境可视化
# 查看当前状态
env.render()
运行结果为:
SFFF
FHFH
FFFH
HFFG
查看环境的观测空间:
# 查看观测空间
print(env.observation_space,env.nS)
运行结果为:
Discrete(16) 16
查看环境的动作空间:
# 查看动作空间
print(env.action_space,env.nA)
运行结果为:
Discrete(4) 4
动作的定义为:
LEFT = 0
DOWN = 1
RIGHT = 2
UP = 3
转移概率
使用动态规划算法需要直到环境的所有信息,即转移概率,可以通过env.P查看环境的所有转移概率:
P[][]本质上是一个“二维数组”,状态和动作分别由数字0-15和0-3表示。存储的是,在状态s下采取动作a获得的一系列数据,即(转移概率,下一步状态,奖励,完成标志)这样的元组。
# 查看环境转移矩阵
print(env.P)
运行结果为:
{
0: {
0: [(0.3333333333333333, 0, 0.0, False), (0.3333333333333333, 0, 0.0, False), (0.3333333333333333, 4, 0.0, False)],
1: [(0.3333333333333333, 0, 0.0, False), (0.3333333333333333, 4, 0.0, False), (0.3333333333333333, 1, 0.0, False)],
2: [(0.3333333333333333, 4, 0.0, False), (0.3333333333333333, 1, 0.0, False), (0.3333333333333333, 0, 0.0, False)],
3: [(0.3333333333333333, 1, 0.0, False), (0.3333333333333333, 0, 0.0, False), (0.3333333333333333, 0, 0.0, False)]
},
1: {
0: [(0.3333333333333333, 1, 0.0, False), (0.3333333333333333, 0, 0.0, False), (0.3333333333333333, 5, 0.0, True)],
1: [(0.3333333333333333, 0, 0.0, False), (0.3333333333333333, 5, 0.0, True), (0.3333333333333333, 2, 0.0, False)],
2: [(0.3333333333333333, 5, 0.0, True), (0.3333333333333333, 2, 0.0, False), (0.3333333333333333, 1, 0.0, False)],
3: [(0.3333333333333333, 2, 0.0, False), (0.3333333333333333, 1, 0.0, False), (0.3333333333333333, 0, 0.0, False)]
},
2: {
0: [(0.3333333333333333, 2, 0.0, False), (0.3333333333333333, 1, 0.0, False), (0.3333333333333333, 6, 0.0, False)],
1: [(0.3333333333333333, 1, 0.0, False), (0.3333333333333333, 6, 0.0, False), (0.3333333333333333, 3, 0.0, False)],
2: [(0.3333333333333333, 6, 0.0, False), (0.3333333333333333, 3, 0.0, False), (0.3333333333333333, 2, 0.0, False)],
3: [(0.3333333333333333, 3, 0.0, False), (0.3333333333333333, 2, 0.0, False), (0.3333333333333333, 1, 0.0, False)]
},
3: {
0: [(0.3333333333333333, 3, 0.0, False), (0.3333333333333333, 2, 0.0, False), (0.3333333333333333, 7, 0.0, True)],
1: [(0.3333333333333333, 2, 0.0, False), (0.3333333333333333, 7, 0.0, True), (0.3333333333333333, 3, 0.0, False)],
2: [(0.3333333333333333, 7, 0.0, True), (0.3333333333333333, 3, 0.0, False), (0.3333333333333333, 3, 0.0, False)],
3: [(0.3333333333333333, 3, 0.0, False), (0.3333333333333333, 3, 0.0, False), (0.3333333333333333, 2, 0.0, False)]
},
4: {
0: [(0.3333333333333333, 0, 0.0, False), (0.3333333333333333, 4, 0.0, False), (0.3333333333333333, 8, 0.0, False)],
1: [(0.3333333333333333, 4, 0.0, False), (0.3333333333333333, 8, 0.0, False), (0.3333333333333333, 5, 0.0, True)],
2: [(0.3333333333333333, 8, 0.0, False), (0.3333333333333333, 5, 0.0, True), (0.3333333333333333, 0, 0.0, False)],
3: [(0.3333333333333333, 5, 0.0, True), (0.3333333333333333, 0, 0.0, False), (0.3333333333333333, 4, 0.0, False)]
},
5: {
0: [(1.0, 5, 0, True)],
1: [(1.0, 5, 0, True)],
2: [(1.0, 5, 0, True)],
3: [(1.0, 5, 0, True)]
},
6: {
0: [(0.3333333333333333, 2, 0.0, False), (0.3333333333333333, 5, 0.0, True), (0.3333333333333333, 10, 0.0, False)],
1: [(0.3333333333333333, 5, 0.0, True), (0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 7, 0.0, True)],
2: [(0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 7, 0.0, True), (0.3333333333333333, 2, 0.0, False)],
3: [(0.3333333333333333, 7, 0.0, True), (0.3333333333333333, 2, 0.0, False), (0.3333333333333333, 5, 0.0, True)]
},
7: {
0: [(1.0, 7, 0, True)],
1: [(1.0, 7, 0, True)],
2: [(1.0, 7, 0, True)],
3: [(1.0, 7, 0, True)]
},
8: {
0: [(0.3333333333333333, 4, 0.0, False), (0.3333333333333333, 8, 0.0, False), (0.3333333333333333, 12, 0.0, True)],
1: [(0.3333333333333333, 8, 0.0, False), (0.3333333333333333, 12, 0.0, True), (0.3333333333333333, 9, 0.0, False)],
2: [(0.3333333333333333, 12, 0.0, True), (0.3333333333333333, 9, 0.0, False), (0.3333333333333333, 4, 0.0, False)],
3: [(0.3333333333333333, 9, 0.0, False), (0.3333333333333333, 4, 0.0, False), (0.3333333333333333, 8, 0.0, False)]
},
9: {
0: [(0.3333333333333333, 5, 0.0, True), (0.3333333333333333, 8, 0.0, False), (0.3333333333333333, 13, 0.0, False)],
1: [(0.3333333333333333, 8, 0.0, False), (0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 10, 0.0, False)],
2: [(0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 5, 0.0, True)],
3: [(0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 5, 0.0, True), (0.3333333333333333, 8, 0.0, False)]
},
10: {
0: [(0.3333333333333333, 6, 0.0, False), (0.3333333333333333, 9, 0.0, False), (0.3333333333333333, 14, 0.0, False)],
1: [(0.3333333333333333, 9, 0.0, False), (0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 11, 0.0, True)],
2: [(0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 11, 0.0, True), (0.3333333333333333, 6, 0.0, False)],
3: [(0.3333333333333333, 11, 0.0, True), (0.3333333333333333, 6, 0.0, False), (0.3333333333333333, 9, 0.0, False)]
},
11: {
0: [(1.0, 11, 0, True)],
1: [(1.0, 11, 0, True)],
2: [(1.0, 11, 0, True)],
3: [(1.0, 11, 0, True)]
},
12: {
0: [(1.0, 12, 0, True)],
1: [(1.0, 12, 0, True)],
2: [(1.0, 12, 0, True)],
3: [(1.0, 12, 0, True)]
},
13: {
0: [(0.3333333333333333, 9, 0.0, False), (0.3333333333333333, 12, 0.0, True), (0.3333333333333333, 13, 0.0, False)],
1: [(0.3333333333333333, 12, 0.0, True), (0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False)],
2: [(0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 9, 0.0, False)],
3: [(0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 9, 0.0, False), (0.3333333333333333, 12, 0.0, True)]
},
14: {
0: [(0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False)],
1: [(0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 15, 1.0, True)],
2: [(0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 15, 1.0, True), (0.3333333333333333, 10, 0.0, False)],
3: [(0.3333333333333333, 15, 1.0, True), (0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 13, 0.0, False)]
},
15: {
0: [(1.0, 15, 0, True)],
1: [(1.0, 15, 0, True)],
2: [(1.0, 15, 0, True)],
3: [(1.0, 15, 0, True)]
}
}
3.策略评估源代码
import numpy as np
import gym
def policy_eval(enviroment,policy,discount_factor=1.0,theta=0.1):
# 引用环境
env = enviroment
# 初始化值函数
V = np.zeros(env.nS)
# 开始迭代
for _ in range(500):
delta = 0
# 扫描所有状态
for s in range(env.nS):
v=0
# 扫描动作空间
for a,action_prob in enumerate(policy[s]):
# 扫描下一状态
for prob,next_state,reward,done in env.P[s][a]:
# 更新值函数
v += action_prob * prob * ( reward + discount_factor * V[next_state])
# 更新最大的误差值
delta=max(delta,np.abs(v-V[s]))
V[s] =v
if delta < theta:
break
return np.array(V)
# 定义策略生成函数
def generate_policy(env,input_policy):
policy=np.zeros([env.nS,env.nA])
for _ , x in enumerate(input_policy):
policy[_][x] = 1
return policy
if __name__=="__main__":
# 创建环境
env=gym.make("FrozenLake-v0")
# 定义动作策略
input_policy=[2,1,2,3,2,0,2,0,1,2,2,0,0,1,1,0] # 定义了在每个状态采取的动作,LEFT = 0、DOWN = 1、RIGHT = 2、UP = 3
# 生成策略
policy=generate_policy(env,input_policy)
Value=policy_eval(env,policy)
print("This is the final value:\n")
print(Value.reshape([4,4]))
运行结果为:
This is the final value:
[[0. 0. 0. 0. ]
[0. 0. 0.03703704 0. ]
[0. 0.07407407 0.17283951 0. ]
[0. 0.19753086 0.55967078 0. ]]
4. 代码解析
首先我们会定义策略生成函数
# 定义策略生成函数
def generate_policy(env,input_policy):
policy=np.zeros([env.nS,env.nA])
for _ , x in enumerate(input_policy):
policy[_][x] = 1
return policy
该函数会生成一个[env.nS,env.nA]大小的数组,然后根据输入的每个状态的策略生成一个矩阵,将该状态的某状态置为1。
例如这里我们要评估策略:
input_policy=[2,1,2,3,2,0,2,0,1,2,2,0,0,1,1,0] # 定义了在每个状态采取的动作,LEFT = 0、DOWN = 1、RIGHT = 2、UP = 3
生成的策略矩阵如下所示:
array([[0., 0., 1., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 1., 0., 0.],
[1., 0., 0., 0.]])
在迭代过程中完全按照公式进行。
网友评论