美文网首页
对话系统-PG模型及训练方式

对话系统-PG模型及训练方式

作者: 又双叒叕苟了一天 | 来源:发表于2018-11-18 16:02 被阅读0次

    原文:https://arxiv.org/abs/1805.09461?context=stat.ML

    Policy Grandient(PG)

    ​ 在强化学习中,代理通过特定的策略采取行为,不同应用的策略都不同。例如:在文本概括任务中,策略就是语言模型 p(y|X),给出 X 输出 y

    ​ 现在假设我们的代理是个 RNN,它采取不连续的行为。RNN 的输出层通常使用 softmax 函数,并从该层产生输出。

    ​ 在 Teacher Forcing 的训练方法中,我们根据策略采取行为。在我们序列结束或看到句子结束标志(EOS)时候,我们比较策略产生的行为(\hat{y_t})和真实行为的值(y_t),基于评价标准就能观察到一个奖励

    ​ 我们的目标就是找到参数以最大化期望的奖励,于是我们把损失函数定义为奖励期望的负数:
    \mathcal{L}_\theta=-\mathbb{E}_{\hat{y}_1,\cdots,\hat{y}_T\sim\pi_\theta(\hat{y}_1,\cdots,\hat{y}_T)}[r(\hat{y}_1,\cdots,\hat{y}_T)]\tag{12}
    \hat{y}_tt 时刻的行为,r(\hat{y}_1,\cdots,\hat{y}_T) 为这个采样序列的奖励。

    实际上,我们只用单个样本去估计了这个期望,所以上面损失函数的导数变成下面这样:
    \nabla_\theta \mathcal{L}_\theta=-\underset{\hat{y}_{1\cdots T}\sim\pi_\theta}{\mathbb{E}}[\nabla_\theta\log{\pi_\theta(\hat{y}_{1\cdots T})r(\hat{y}_{1\cdots T})}]\tag{13}
    根据链式法则,我们写成这样:
    \nabla_\theta \mathcal{L}_\theta=\frac{\partial \mathcal{L}_\theta}{\partial\theta}=\sum_t\frac{\partial \mathcal{L}_\theta}{\partial o_t}\frac{\partial o_t}{\partial\theta}\tag{14}
    其中 o_t 是 softmax 函数的输入。所以,损失 \mathcal{L}_\theta 对于 o_t 的梯度:
    \frac{\partial \mathcal{L}_\theta}{\partial o_t}=\bigg(\pi_\theta(y_t|\hat{y}_{t-1},s_t,c_{t-1})-1(\hat{y}_t)\bigg)(r(\hat{y}_1,\cdots,\hat{y}_T)-r_b)\tag{15}
    其中 1(\hat{y}_t)\hat{y}_t 的 1-of-|\mathcal{A}| 表示,r_b 是一个基线的奖励,它可以是任何值,因为它不依赖与RNN的参数。

    ​ 这个式子类似于多分类逻辑回归的梯度。在逻辑回归中,交叉熵梯度是预测和实际输出 1-of-|\mathcal{A}| 表示之间的差:
    \frac{\partial \mathcal{L}_\theta^{CE}}{\partial o_t}=\pi_\theta(y_t|y_{t-1},s_t,c_{t-1})-1(y_t)\tag{16}
    我们注意到(15)中,我们用到了模型产生的输出,而(16)全是用真实值去计算梯度。

    基线奖励(r_b)的目的是使模型采取奖励 r\gt r_b 的行为,而不鼓励 r\lt r_b 的行为。由于我们只使用一个样本去计算损失的梯度,这个基线还将减少梯度估计量的方差。如果基线不依赖于模型的参数 \theta ,(15)将会是损失梯度的无偏估计量。下面证明一下添加基线奖励 r_b 对于损失的期望没有任何影响
    \begin{align} \mathbb{E}_{\hat{y}_1\cdots T\sim\pi_\theta}[\nabla_\theta\log{\pi_\theta(\hat{y}_{1\cdots T})r_b}]&=\\ r_b\sum_{\hat{y}_{1\cdots T}}\nabla_\theta\pi_\theta(\hat{y}_{1\cdots T})&=\\ r_b\nabla_\theta\sum_{\hat{y}_{1\cdots T}}\pi_\theta(\hat{y}_{1\cdots T})&=\\ r_b\nabla_\theta1&=0 \end{align} \tag{17}
    ​ 这个算法叫做 REINFORCE,它是一个解决 seq2seq 问题的策略梯度算法。这个模型的一个问题就是每个时间步只使用一个样本来训练模型,所以模型会存在高方差的问题。为了缓解该问题,我们每次训练都用 N 个行为序列的样本,通过平均 N 个序列来更新梯度:
    \mathcal{L}_\theta=\frac1N\sum_{i=1}^N\sum_t\log{\pi_\theta(y_{i,t}|\hat{y}_{i,t-1},s_{i,t},c_{i,t-1})}\times(r(\hat{y}_{i,1},\cdots,\hat{y}_{i,T})-r_b)\tag{18}
    其中 r_b=\frac1N\sum_{i=1}^{N}r(\hat{y}_{i,1},\cdots,\hat{y}_{i,T})

    ​ 解决模型高方差问题的另一个方法就是 SC(Self-Critic)模型。它没有使用样本去估计基线,而是使用模型推导阶段通过贪心搜索获得的输出作为基线。因此,我们使用模型的采样输出作为 \hat{y}_t,并使用最终输出分布的贪心选择获得 \hat{y}_t^g,上标 g 表示贪心选择。通过这种方式,REINFORCE 模型的损失为:
    \mathcal{L}_\theta=\frac1N\sum_{i=1}^N\sum_t\log\pi_\theta(y_{i,t}|\hat{y}_{i,t-1},s_{i,t},c_{i,t-1})\times\bigg(r(\hat{y}_{i,1},\cdots,\hat{y}_{i,T})-r(\hat{y}_{i,1}^g,\cdots,\hat{y}_{i,T}^g)\bigg)\tag{19}
    REINFORCE 算法的训练、测试步骤

    输入:输入序列 X,真实值序列 Y,并且最好有个与训练的策略 \pi_\theta

    输出:采用 REINFORCE 训练过的策略。

    训练步骤

    while 尚未收敛 do

    1. 选择 batch = N 个输入序列 X 和输出序列 Y
    2. 采样 N 个完整的行为序列:\{\hat{y}_1,\cdots,\hat{y}_T\sim\pi_\theta(\hat{y}_1,\cdots,\hat{y}_T)\}_1^N
    3. 观察序列的奖励并且计算基线奖励 r_b
    4. 根据(18)计算损失。
    5. 更新网络参数:\theta\leftarrow\theta+\alpha\nabla_\theta\mathcal{L}_\theta

    end while

    测试步骤

    for batch 个输入和输出序列 XY do

    1. 使用训练过的模型和 \hat{y}_{t'}=\underset{y}{\arg \max}\pi_\theta(y|\hat{y}_t,s_{t'}) 采样输出 \hat{Y}
    2. 使用性能度量指标,例如:ROUGE_l 来评估模型。

    end for

    reinforce算法

    这个方法的第二个问题就是,我们在行为采样完成后才能看到奖励,这种特性在大多数 seq2seq 模型中都不是很好。如果能中途看到部分奖励,而且奖励不太好,我们就能够去选择更好的行为了。但是 REINFORCE 做不到这点。所以,这个模型经常产生很差的结果,并且要花更多时间去收敛。这个问题在模型刚随机初始化的时候最严重。为了缓解这个问题,Ranzato et al.建议采用交叉熵预训练,再逐渐替换成REINFORCE损失

    另一个解决 REINFORCE 算法高方差的方法就是使用重要性抽样(importance sampling )。其基本思想就是从旧模型而不是当前模型采样序列来计算损失。

    相关文章

      网友评论

          本文标题:对话系统-PG模型及训练方式

          本文链接:https://www.haomeiwen.com/subject/bjsgfqtx.html