本文主要内容来源于 Berkeley CS285 Deep Reinforcement Learning
在 深度强化学习(11)DQN - Deep Q Learning (1) 中,我们介绍了 Replay Buffer, 在本节我们将继续讨论 DQN 。
Target Network
Replay Buffer 解决了 i.i.d. 的问题。 根据上文, Q Learing 还有一个问题, 就是 Q Network在训练的时候, 会遇到 Target 变化的问题:
![](https://img.haomeiwen.com/i25067830/a48aacff7f6def67.png)
一个自然而然的想法是, 我们利用另外一个 Neural Network, 如果这个 Neural Network 不变的话,我们就解决了这个问题。 这个 Neural Network, 我们称其为 Target Network 记为 (之前的Network 称作 Current Network 记为
)。 这样的话, Q-learning 就可以变成 DQN:
![](https://img.haomeiwen.com/i25067830/5b855a2fc812fe7c.png)
其实 Target Network 也需要训练, 但是我们并不会和Current Network 同步训练, 而是在Current Network 训练了 N 步以后, 再更新参数。 更新参数的方法其实很暴力, 就是直接拷贝
的参数。严格来说,问题还是存在,只是大部分被这种延迟更新解决了。
伪代码:
# gradient_steps: 在N步以后, 更新Target Network 的参数
for _ in range(gradient_steps):
with torch.no_grad():
# Step1 生成一些新的 Transation,加入Replay Buffer
# 这部分不重要, 先跳过
pass
# Step2 从 Replay Buffer 中取一些 Sample
state, action, state_prime, reward = replay_buffer(sample_size)
# Step3 Compute lable by Target Network
# 基于 State Prime, 找出所有可能 Action 的 Q value,
# 假设 ACTION_SPACE 已知, 类型为 list
Q_of_next_action = []
for next_action in ACTION_SPACE:
Q = target_net(state_prime, next_action)
Q_of_next_action.append(Q)
# 找出 Q 值最大的 Action , 作为 Action Prime
new_action_index = argmax(Q_of_next_action)
action_prime = ACTION_SPACE[action_prime]
# GAMMA is a hyper-prarameter, normally 0.99
target_q_values = reward + GAMMA * target_net(state_prime, action_prime)
# Step 4 更新 Current Network 参数
current_q_values = current_net(state)
loss = smooth_l1_loss(current_q_values, target_q_values)
# Optimize the policy (Current Network)
policy.optimizer.zero_grad()
loss.backward()
# 在 N 步以后,更新 target_net 参数
target_net.param = current_net.param.copy()
在更新 Target Network 的时候, 加入衰减变量
因为更新网络参数的延迟问题, 有人建议引入权重, 离更新越远的样本, 权重越低(时效性差)。
![](https://img.haomeiwen.com/i25067830/1a7a879483b11048.png)
网友评论