梯度消失与爆炸
假设一个只有 3 个输入数据的序列,此时我们的隐藏层 h1、h2、h3 和输出 y1、y2、y3 的计算公式:
RNN 在时刻 t 的损失函数为 Lt,总的损失函数为
t = 3 时刻的损失函数 L3 对于网络参数 U、W、V 的梯度如下:
其实主要就是因为:
- 对V求偏导时,是常数
- 对U求偏导时:
-
里有U,所以要继续对h3应用
chain rule
-
里的是常数,但是里又有U,继续
chain rule
- 以此类推,直到
-
里有U,所以要继续对h3应用
- 对W求偏导时一样
所以:
- 参数矩阵 V (对应输出 ) 的梯度很显然并没有长期依赖
- U和V显然就是连乘()后累加()
其中的连乘项就是导致 RNN 出现梯度消失与梯度爆炸的罪魁祸首,连乘项可以如下变换:
tanh' 表示 tanh 的导数,可以看到 RNN 求梯度的时候,实际上用到了 (tanh' × W) 的连乘。当 (tanh' × W) > 1 时,多次连乘容易导致梯度爆炸;当 (tanh' × W) < 1 时,多次连乘容易导致梯度消失。
网友评论