美文网首页
深度强化学习(12)DQN - Deep Q Learning

深度强化学习(12)DQN - Deep Q Learning

作者: 数科每日 | 来源:发表于2022-02-17 04:21 被阅读0次

本文主要内容来源于 Berkeley CS285 Deep Reinforcement Learning


深度强化学习(11)DQN - Deep Q Learning (1) 中,我们介绍了 Replay Buffer, 在本节我们将继续讨论 DQN 。

Target Network

Replay Buffer 解决了 i.i.d. 的问题。 根据上文, Q Learing 还有一个问题, 就是 Q Network在训练的时候, 会遇到 Target 变化的问题:

Problem

一个自然而然的想法是, 我们利用另外一个 Neural Network, 如果这个 Neural Network 不变的话,我们就解决了这个问题。 这个 Neural Network, 我们称其为 Target Network 记为\Phi^{\prime} (之前的Network 称作 Current Network 记为 \Phi)。 这样的话, Q-learning 就可以变成 DQN

DQN : Deep Q Learning

其实 Target Network 也需要训练, 但是我们并不会和Current Network 同步训练, 而是在Current Network 训练了 N 步以后, 再更新参数。\Phi^{\prime} 更新参数的方法其实很暴力, 就是直接拷贝 \Phi的参数。严格来说,问题还是存在,只是大部分被这种延迟更新解决了。

伪代码:

# gradient_steps: 在N步以后, 更新Target Network 的参数
for _ in range(gradient_steps):

    with torch.no_grad():
    
        # Step1 生成一些新的 Transation,加入Replay Buffer
        #       这部分不重要, 先跳过
        pass 
            
        # Step2  从 Replay Buffer 中取一些 Sample
        state, action, state_prime, reward = replay_buffer(sample_size)
        
        # Step3  Compute lable by Target Network 
        # 基于 State Prime, 找出所有可能 Action 的 Q value,
        # 假设 ACTION_SPACE 已知, 类型为 list
        Q_of_next_action = []
        for next_action in ACTION_SPACE:
            Q = target_net(state_prime, next_action)
            Q_of_next_action.append(Q)   
            
        # 找出 Q 值最大的 Action , 作为 Action Prime
        new_action_index = argmax(Q_of_next_action)
        action_prime = ACTION_SPACE[action_prime]
        
        # GAMMA is a hyper-prarameter, normally 0.99
        target_q_values = reward + GAMMA * target_net(state_prime, action_prime)
    
    # Step 4 更新 Current Network 参数
    current_q_values = current_net(state)
    loss = smooth_l1_loss(current_q_values, target_q_values)
    
    # Optimize the policy (Current Network)
    policy.optimizer.zero_grad() 
    loss.backward()

# 在 N 步以后,更新 target_net 参数
target_net.param = current_net.param.copy()
在更新 Target Network 的时候, 加入衰减变量 \tau

因为更新网络参数的延迟问题, 有人建议引入权重, 离更新越远的样本, 权重越低(时效性差)。

新的 Update 函数

相关文章

网友评论

      本文标题:深度强化学习(12)DQN - Deep Q Learning

      本文链接:https://www.haomeiwen.com/subject/mpejlrtx.html