美文网首页
DQN入坑教程

DQN入坑教程

作者: 社交达人叔本华 | 来源:发表于2022-10-29 21:12 被阅读0次

    最近在自学DQN,主要参考了Pytorch 上的这个DQN教程。本文首先从直观上介绍一下DQN是什么,然后参考教程给出CartPole 这个例子的介绍,利用这个例子本文将会对DQN进行一个讲解,最后会罗列一下在学教程的过程中学到的一些编程的骚操作。

    1. 什么是DQN?

    简单来说,DQN是将强化学习中经典算法QLearning 和深度学习的技术相结合,是QLearning 在“新时代”的发展哈哈哈。本小节将会从强化学习开始讲起,然后介绍一下经典的Qlearning 算法,最后会介绍DQN的基本理念和核心要点。

    1.1 什么是强化学习(简单的背景知识小补充)?

    强化学习,是机器学习的一种,主要用来解决需要连续决策的情况(sequential decision making),比如操作一个机器人走迷宫需要不断告诉机器人每一步的操作应该是什么。所以如果能够把某一个task, map成一个连续决策的问题,那么这个task就可以采用强化学习的方法来解决。

    通常,我们把强化学习的方法拆解成两个部分:Agent和Environment。我们认为Agent通过和Environment 进行交互获得一些关于environment的有价值的信息,从而能够做出更好的决策。交互的越多,获得信息越多,就能做出更好的选择,这个过程就是learning。

    下面我们呢,用一些稍微正式一点的符号来解释一下强化学习。我们认为一个强化学习的过程应该包括:

    • 初始时,environment 处于某一个状态s_0 ,然后这个状态体现为观察值w_0 (很多时候,我们并不知道环境处于什么样的状态,只能从一些观察值推理出环境的状态信息) 。
    • 在观察到观察值w_0 之后,agent 做出一个判断认为环境处于状态s_x (不一定是s_0,谁知道这个agent行不行呢,它能不能推断出正确的状态呢)。然后这个agent根据某一个自己的policy,结合判断出来的状态,执行某一个对自己最有利的action
    • 在agent执行完action之后,就改造了environment!使得environment 发生了状态的改变,怎么改呢?是根据一个叫做transition function的东西进行转换的 。这个函数给定一个初始状态和转换后的状态,以及执行的action, 该函数可以给出进行这样转换的概率。T: S \times A \times S \rightarrow [0,1] 在完成这样的状态转换之后,环境会给于agent一个对应的rewards。具体是什么样的reward,是取决于reward函数的。给定一个初始状态和转换后的状态以及执行的action,该函数可以给出进行该action的Reward。R:S*A*S->R
    • 重复上面步骤直到终止状态。


      Screen Shot 2022-10-28 at 16.56.17.png

    这就是强化学习稍微正式一点的过程解释。在实际的运用中,我们通常会对这个过程在简化一点点,简化成马尔科夫决策过程(Markov Decision Process)。其实简化的地方很少,就是我们认为不存在观察值这么一说,就是我们直接能观察到状态(老厉害了我们),其他过程都不会发生变化。具体可以参考下面这张图。

    Screen Shot 2022-10-28 at 16.56.28.png

    有了上面的解释,我们就能够把一个实际的问题,转换强化学习的问题。那么现在我们想要知道,强化学习又是怎么学习的呢, 我们怎么一步一步的提高这个强化学习的方法的呢?

    首先,我们回顾一下刚刚讲的那个强化学习的过程(没错我就是要回顾,明明刚刚讲过?明明上上一段讲的内容?哎对,我就是要再讲一下,怎么滴)。我记得我初学的时候看到这个过程,就很困惑。什么?这不是一切都给你了,状态怎么转移你知道了,reward怎么计算你也知道了,agent如何做决策你也知道了,那不就计算一下就能知道每一步的决策应该是什么了吗?如果你也有类似的疑问,那答案就是我们其实并不是知道所有的这些函数的,是需要我们自己从数据中或者从和环境的不断交互中获取到充分的信息,不断的学习才能得到的,比如transition function。

    强化学习的应用场景其实很多,有的应用场景呢是知道transition function 的,有的是不知道。知道transition function的呢,我们发明了一类强化学习的方法来解决这类问题,这些强化学习方法就叫model-based reinforcement learning. 这一类方法呢主要是非深度学习的方法,比如动态规划(dynamic programming)。另外那些不知道transition function的应用场景呢,我们需要从数据中显式或者隐式的学习到这个transition function,我们同样也有一类专门的强化学习方法用来解决这类问题,这些强化学习的方法被称为model-free reinforcement learning。这一类方法呢又分成了value-based methods 和policy based methods。 本文的重点DQN 就属于典型的value-based 的方法。


    Screen Shot 2022-10-28 at 16.45.49.png

    value based 方法其实是非常大的一个类别,上面这张图主要目的是为了告诉大家DQN在整个领域里的什么位置,所以并没有罗列其他的value-based methods。 那么,什么是value-based methods呢。我们都已经知道了,强化学习的目的就是为了能够更好的连续决策,那么我们直观的想一想,我们如何去做这样的决策。我们决策的依据应该是我们总是希望能够做出有利于自己的决策,有利可图的决策,或者利益(长远利益)最大的决策。决策这个词也有点宽泛,就是我们想要采用利益最大的action。所以,咚咚咚咚咚(drum roll),哎!我们如果给每个action在每个状态都打个分,按照他们的价值排个序,不是就可以了,就选价值最大的action就行了!

    那么,问题来了,我们想要怎么打分?什么样的action的价值比较高?当然是能带来的长远利益比较多的action价值比较高啦(也没有那么绝对哈哈,还是有些情况是只考虑短期利益的,不过不在这里讨论那么全啦,这只是一个DQN的教程,没办法把所有的东西都顾及到呢)!所以我们想要用未来所有reward的期望作为action的价值!那么我们怎么求这个所有未来reward呢?

    最简单的方法是Monte carlo算法,用模拟代替计算,不过这个并不是本文的重点,就不详细介绍啦。还有一种方法呢,是Q learning。Q learning 是 Deep Q learning 的前身。 不过呢, Q learning 并没有试图用未来所有的reward作为action的价值,而是做了一个近似,求了一个q value 。 假如说给了一个当前的状态s_t ,我们认为在这个状态下,采取action a_t,状态转移到s_t'的价值e可以用下面这个公式计算Q(s_t,a_t):=Q(s_t,a)+\alpha[reward+\gamma \cdot max_a Q(s_t',a)-Q(s_t,a_t)]

    好吓人的公式!不过直观解释一下其实挺简单的。前面说了,这个公式是近似求了一个未来所有的reward的和。怎么求的呢。对于任何一个状态s_t我们采用action a_t之后这个状态就会发生转移,并且产生一些reward,这些reward呢是我们受益的一部分,但并不是全部。状态转移到了s_t'之后呢。我们的想法是,如果我们能给这个状态也打个分,给它评个价值,用来表示以后所有的收益和不就行了?那么怎么表示呢?我们的直观理解是,在状态s_t'我们能采取好多好多action,这每个action都有一定价值(qvalue),我们是不是可以把这些action的价值求个和,或者求个平均,或者取个最大值作为这个状态的价值?答案是肯定的。而在Q learning 中,我们采用的是最大值,而不是平均值或者求和,如果用平均值或者求和也是可以做的,只不过是另外一种算法,有另外的特性,并不是本文需要介绍的内容。有些人可能要问了,你说的这个状态的价值还是要用action的价值来表示,好像哪里不对啊。我们不就是要求action的价值吗,你这用另外一个action的价值来求这个action的价值?耍我呢?没错!不是没错耍你呢,而是没错我们就是要用一个action的价值来求另外一个action的价值,这是一种类似于迭代的求法。我们把它叫做Bootstrapping。有很多的强化学习的方法都用到这种bootstrapping。但是这个讲起来或者证明起来都有点困难,有点浪费篇幅,所以我们就不在这里说明了,你只需要知道,这个确实是可以使用的。它的理论依据是Belleman optimization equation。

    另外啊,有个小问题,可能有些人要问,为什么右边不是直接reward加未来reward的估算啊,右边这是什么玩意。其实这个东西变个性你可能就看得懂啦。看下面这个公式,新的qvalue其实是1-\alpha倍的老qvalue加上\alpha倍的reward和未来reward的和。这是在干啥?这是在控制学习的速度呀同志们!我们不想让新来的那么耀武扬威的,老同志们的脸往哪儿放?我不允许,只能让新来的按照跟老同志商量着来,不管是三七分还是八二分,总要让老同志发挥点作用啊你说对不对。Q(s_t,a_t):=(1-\alpha)Q(s_t,a)+\alpha[reward+\gamma \cdot max_a Q(s_t',a)]

    总结一下,有了这个公式,我们就能给每个状态的每个action计算一个qvalue,所有的状态和所有的action求出来的qvalue就组成了一直张表格,这个表格呢就叫Q table。

    这就是Q learning 的主要内容啦, 那么这还有一个问题就是,如果状态很多呢,多到没有什么表格能表示的下呢,至少计算机里是没办法保存的。为了解决这个问题 ,我们很自然的就想到了,神经网络,深度学习。深度学习在背表格方面那可是太专业了。所以说, deep q learning 简单来说就是用深度学习的神经网络背了一张巨大无比的表格q table。复杂来说,DQN就是存了一张大表,然后加上一些骚操作,比如双网络,比如ReplayMemory。

    2. 举个栗子 (CartPole task)

    为了能够更好的说明DQN,我们参考了pytorch那篇官方教程中给出的例子Cartpole problem。这个任务其实挺简单,就是一个小车中间差了一个木棍,我们可以操控小车往左还是往右,我们想通过我们的操作让这个小车多活一会,也就是木棍能够更长时间都保持平衡不至于脸着地。


    Pasted image 20221029180920.png

    按照强化学习的术语来总结一下这个task呢就是:

    • State: 每一个时刻,小车的[速度,角速度,位置,角度]组成了一个向量,每一组值就是一个状态。这是环境的状态,但是我们在模型训练的时候,讨了个巧,用这一帧的图像减去上一帧的图像来表示变换,我们把这个变换作为我们的状态,也是作为我们模型的输入。我们之所以能讨这个巧,是因为,我们这样做得到的状态其实是包含了我们刚刚那四个信息的,甚至还包含了更多的信息。
    • Action: 我们的agent就是小车咯,它能进行的操作其实就两个,向左和向右。
    • Reward function: 我们认为小车每多活一秒都是胜利!所以每一步成功的没死都应该得到+1的奖励,死了就没得奖励+0。
    • Transition function: 这不知道!因为是model-free 嘛,我们没有对环境建模哦。

    3. DQN 细节&代码讲解(Pytorch 官方教程代码)

    本文呢,对pytorch 官网的教程,按照作者自己的习惯进行了简单的重构。官方教程为了方便大家能够在google colab里直接执行,把所有的东西都写到了一个文件里。本文呢,把这个教程分成了三个文件:

    • preprocessing.py。主要用来处理输入数据,把图片处理成我们想要的格式。顺便定义了ReplayMemory这个类,用于存储以前的数据。
    • model.py。 主要用来定义DQN 模型。
    • train.py。主要用来定义训练的过程。

    我们就按照训练过程为主线,来讲解一下整个代码,顺便介绍一下DQN的细节。DQN主要包括这样几个步骤:

    • 准备数据。
    • 准备模型。我们有两个模型,一个叫policy network,一个叫target network。我们前面一直在说我们使用深度学习模型来背那个巨大的q table。其实现实要比这个稍微复杂一点点。
      • 首先是我们怎么背?我们的做法是设计这样一个深度学习模型,它的输入是状态向量,输出是对应的各个action的q value。当然也可以有别的设计,比如吧状态和action同时作为输入,输出也是qvalue等,但是效果不如这个好啦。
      • 其次,我们都知道深度学习是要计算一个loss值,并且进行梯度下降才能学习的。我们现在能够通过以state作为输入,计算出每个action的价值了,那我们怎么知道这个action的价值是不是对的?我们需要有一个golden standard啊?这个golden standard就相当于分类问题里面的标签一样,有它我们才能计算损失值。那么这个golden standard从哪里来呢?从它自己来!记得我们的belleman optimization function吗?我们可以迭代的求啊!记得这个公式吗Q(s_t,a_t):=Q(s_t,a)+\alpha[reward+\gamma \cdot max_a Q(s_t',a)-Q(s_t,a_t)] 我们完全可以用公式右边作为一个golden standard,不过捏,首先那个控制学习速度的东西我们用不着了,因为我们梯度下降的时候也会控制学习率呢,不要重复搞啦就。其次,如果我们每一个输入都用这种方式计算golden standard,那我们的模型可是要疯了!!你的数据样本的随机性那么大,你的golden standard自己也不稳定,那我还要从你这两个东西里学到最佳参数??是不是难为人?是不是不给脸?能不能懂事点稳定点?答案是能,所以在DQN中的做法就是我们用了两个一模一样的模型。一个叫policy network,我们就正常的每一步从输入计算输出,还有一个模型叫target network,这个网络和刚刚那个网络结构一模一样,但是他非常特别,它的参数不是学习来的,是我们从policy network中复制过来的,但是我们是每隔一段时间复制一次,比如50个batch复制一次。并且我们是使用这个target network来计算我们的golden standard的。是不是很迷惑?这样做的好处是什么?其实就是我们刚刚说的,我们的golden standard就可以变得很稳定了啊,它50个batch才会更新一次target network好歹让我们的policy network可以慢慢的学一学你在变。
    • 开始训练
      • 观察state,根据policy选择action,然后执行action。在这个教程中,我们采用的是epsilon-greedy 的policy,其中epsilon值还是在衰减的。这个Policy的选择是自由的,你当然可以换成别的啦。
            action = epsilon_greedy_policy(state, policy_network, n_actions, eps_start, eps_end, eps_decay, steps_done)  
            steps_done += 1  
            _, reward, done, _, _ = env.step(action.item())
    
     - action 执行了之后,环境的状态将会发生改变。我们将原状态,转移后的状态,reward,还有action的信息放在一起看做是一个训练数据,然后把这个训练数据存入到ReplayMemory。ReplayMemory其实就是个列表,存了很多这样的数据,我们每次都是从这个memory里取的以前的数据,一个batch一个batch的取,这样能够降低一点我们数据的随机性,让我们的训练能够更加的稳定。
    ```python
    reward = torch.tensor([reward])  
    last_screen = current_screen  
    current_screen = get_screen(env)  
    if not done:  
        next_state = current_screen - last_screen  
    else:  
        next_state = None  
    memory.push(state, action, next_state, reward)
    
        - 从memory中取出一个batch的数据。
    ```python
        if len(memory) > batch_size:  
            batch_data = memory.sample(batch_size)  
            batch_data = Transition(*zip(*batch_data))  
            state_batch =torch.cat(batch_data.state)  
            reward_batch =torch.cat(batch_data.reward)  
          
            action_batch = torch.cat(batch_data.action)  
            next_state_batch = torch.cat([s for s in batch_data.next_state if s is not None])  
            non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch_data.next_state)), dtype=torch.bool)
    
    ``
    - 用取出的batch数据作为policy network 输入,计算各个action 的价值。利用target network 来计算一个golden standard,再计算huberloss,进行梯度下降
    
        state_action_values = policy_network(state_batch).gather(1, action_batch)  
        next_state_values = torch.zeros(batch_size)  
        next_state_values[non_final_mask] = target_network(next_state_batch).max(1)[0].detach()  
        expected_state_action_values = next_state_values * gamma + reward_batch  
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))  
        optimizer.zero_grad()  
        loss.backward()  
        for param in policy_network.parameters():  
            param.grad.data.clamp_(-1, 1)  
        optimizer.step()
    
    - 如果有必要(隔了50个batch了)就更新一下target network。
    
        if episode_i % target_update == 0:  
            target_network.load_state_dict(policy_network.state_dict())
    

    4. 教程中学到的一些编程骚操作

    • Namedtuple 很好用哦,可以按照名字来存取数据,又可以轻松的得到对应的列表。
    • Huber loss能够更好的处理outlier,不像MSE一样受到outlier影响那么大。
    • Count()函数可以自增计数
    • pytorch的unfold 函数很好玩,可以把一个Tensor一折一折的打开。在本文里用来画了个小图。
    • clamp_()函数可以将梯度限制在某一个范围内。
    • slice()函数可以用来生成slicing 列表,slicing列表可以用来挑选元素。

    5. 最后贴上完整的代码

    preprocessing.py

    """  
    Author: xuqh  
    Created on 2022/10/21  
    """  
    import random  
    from collections import namedtuple, deque  
      
    import numpy as np  
    import torch  
    import torchvision.transforms as T  
    from PIL import Image  
    import gym  
    import matplotlib.pyplot as plt  
    import os  
      
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  
    resize = T.Compose([  
        T.ToPILImage(),  
        T.Resize(40, interpolation=Image.CUBIC),  
        T.ToTensor()  
    ])  
      
      
    def get_cart_location(env):  
        """find the location of the cart"""  
        screen = env.render()  
        _, screen_width, _ = screen.shape  
        world_width = env.x_threshold * 2  
        scale = screen_width / world_width  
        return int(env.state[0] * scale + screen_width / 2.0)  # middle of the cart  
      
      
    def get_screen(env):  
        """get an image of the environment"""  
        screen = env.render().transpose((2, 0, 1))  # [channel, height, width]  
        # get only the cart image (lower half)    _, screen_height, screen_width = screen.shape  
        screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]  # top left corner is the origin  
        view_width = int(screen_width * 0.6)  
        cart_location = get_cart_location(env)  
        if cart_location < view_width // 2:  
            slice_range = slice(view_width)  # only select first half, what if it is close to the center  
        elif cart_location > (screen_width - view_width // 2):  
            slice_range = slice(-view_width, None)  # only select second half  
        else:  
            slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)  # select center half  
      
        screen = screen[:, :, slice_range]  
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255  
        screen = torch.from_numpy(screen)  
        return resize(screen).unsqueeze(0)  # add a batch dimensionÒ  
      
      
    Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))  
      
      
    class ReplayMemory(object):  
        def __init__(self, capacity):  
            self.memory = deque([], maxlen=capacity)  
      
        def push(self, *args):  
            self.memory.append(Transition(*args))  
      
        def sample(self, batch_size):  
            return random.sample(self.memory, batch_size)  
      
        def __len__(self):  
            return len(self.memory)  
      
      
    if __name__ == '__main__':  
        env = gym.make("CartPole-v0", render_mode="rgb_array").unwrapped  
        env.reset()  
        # plt.figure()  
        # plt.imshow(get_screen(env).cpu().squeeze(0).permute(1, 2, 0).numpy(), interpolation="none")    # plt.title("Example extracted screen")    # plt.show()    print(get_screen(env))  
        print(get_screen(env))  
        print(get_screen(env))
    

    model.py

    """  
    Author: xuqh  
    Created on 2022/10/21  
    """  
    import torch.nn as nn  
    import torch.nn.functional as F  
      
      
    class DQN(nn.Module):  
        def __init__(self, h, w, outputs):  
            super(DQN, self).__init__()  
            self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)  
            self.bn1 = nn.BatchNorm2d(16)  
            self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)  
            self.bn2 = nn.BatchNorm2d(32)  
            self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)  
            self.bn3 = nn.BatchNorm2d(32)  
            conv2d_size_out = lambda size, kernel_size=5, stride=2: (size - kernel_size) // stride + 1  
            convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))  
            convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))  
            linear_input_size = convw * convh * 32  
            self.head = nn.Linear(linear_input_size, outputs)  
      
        def forward(self, x):  
            x = F.relu(self.bn1(self.conv1(x)))  
            x = F.relu(self.bn2(self.conv2(x)))  
            x = F.relu(self.bn3(self.conv3(x)))  
            return self.head(x.view(x.size(0), -1))```
    
    **train.py**
    ```python
    """  
    Author: xuqh  
    Created on 2022/10/21  
    """  
    import math  
    import random  
    from itertools import count  
      
    import gym  
    import torch  
    from torch import optim, nn  
      
    from DQN.model import DQN  
    from DQN.preprocessing import ReplayMemory, get_screen, Transition  
    import matplotlib.pyplot as plt  
      
      
    def epsilon_greedy_policy(state, policy_network, n_actions, eps_start, eps_end, eps_decay, steps_done):  
        eps = eps_end + (eps_start - eps_end) * math.exp(-1.0 * steps_done / eps_decay)  
        if random.random() > eps:  
            action = policy_network(state).max(1)[1].view(1, 1)  # [N, n_action,1]  
        else:  
            action = torch.tensor([[random.randrange(n_actions)]], dtype=torch.long)  
        return action  
      
      
    def plot_duration(episode_durations):  
        """plot duration for each episode and average duration"""  
        plt.figure()  
        plt.clf()  
        duration_t = torch.tensor(episode_durations, dtype=torch.float)  
        plt.title("Training")  
        plt.xlabel("Episode")  
        plt.ylabel("Duration")  
        plt.plot(episode_durations)  
        # plot average duration as well  
        if len(episode_durations)>100:  
            means = duration_t.unfold(0, 100, 1).mean(1).view(-1)  
            means = torch.cat((torch.zeros(99), means))  
            plt.plot(means.numpy())  
        plt.pause(0.001)  
      
      
    if __name__ == '__main__':  
        # Initialization  
        n_episodes = 5000  
        gamma = 0.999  
        batch_size = 128  
        eps_start = 0.9  
        eps_end = 0.05  
        eps_decay = 200  
        target_update = 10  
        memory_capacity = 10000  
        steps_done = 0  
        # prepare environment & data  
        env = gym.make("CartPole-v0", render_mode="rgb_array").unwrapped  
        env.reset()  
        memory = ReplayMemory(memory_capacity)  
        init_screen = get_screen(env)  
        _, _, screen_height, screen_width = init_screen.shape  
        n_actions = env.action_space.n  
        # prepare model  
        policy_network = DQN(screen_height, screen_width, n_actions)  
        target_network = DQN(screen_height, screen_width, n_actions)  
        target_network.load_state_dict(policy_network.state_dict())  
        optimizer = optim.RMSprop(policy_network.parameters())  
        criterion = nn.SmoothL1Loss()  
        # start training  
        episode_durations = []  
        for episode_i in range(n_episodes):  
            env.reset()  
            last_screen = get_screen(env)  
            current_screen = get_screen(env)  
            state = current_screen - last_screen  # 0 since the env did not change  
            for t in count():  # 有点东西  
                # take actions and get new transitions            action = epsilon_greedy_policy(state, policy_network, n_actions, eps_start, eps_end, eps_decay, steps_done)  
                steps_done += 1  
                _, reward, done, _, _ = env.step(action.item())  
                reward = torch.tensor([reward])  
                last_screen = current_screen  
                current_screen = get_screen(env)  
                if not done:  
                    next_state = current_screen - last_screen  
                else:  
                    next_state = None  
                memory.push(state, action, next_state, reward)  
                state = next_state  # move to next state, allowing environment to continue  
      
                # acquire data from memory            if len(memory) > batch_size:  
                    batch_data = memory.sample(batch_size)  
                    batch_data = Transition(*zip(*batch_data))  
                    state_batch =torch.cat(batch_data.state)  
                    reward_batch =torch.cat(batch_data.reward)  
      
                    action_batch = torch.cat(batch_data.action)  
                    next_state_batch = torch.cat([s for s in batch_data.next_state if s is not None])  
                    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch_data.next_state)), dtype=torch.bool)  
      
                    # train model  
                    state_action_values = policy_network(state_batch).gather(1, action_batch)  
                    next_state_values = torch.zeros(batch_size)  
                    next_state_values[non_final_mask] = target_network(next_state_batch).max(1)[0].detach()  
                    expected_state_action_values = next_state_values * gamma + reward_batch  
                    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))  
                    optimizer.zero_grad()  
                    loss.backward()  
                    for param in policy_network.parameters():  
                        param.grad.data.clamp_(-1, 1)  
                    optimizer.step()  
      
                    # update target network  
                    if episode_i % target_update == 0:  
                        target_network.load_state_dict(policy_network.state_dict())  
                if done:  
                    break  
            episode_durations.append(t)  
            plot_duration(episode_durations) 
    

    相关文章

      网友评论

          本文标题:DQN入坑教程

          本文链接:https://www.haomeiwen.com/subject/gqoctdtx.html