美文网首页
机器学习面试题-为啥LSTM比RNN好

机器学习面试题-为啥LSTM比RNN好

作者: 无所不知的大刘 | 来源:发表于2020-03-25 23:14 被阅读0次

问题引入

其实这算是个经典的问题了,在一般的只要你做过时间序列相关的项目或者理论的时候,LSTM和RNN的对比肯定是要问的。那两者到底有啥区别呢?

问题回答

其实对于这个问题,要从RNN在发展过程中带来的令人诟病的短处说起,RNN在train参数的时候,需要反向传播梯度,这个梯度是这么算的:
w^{i+1}=w^{i}-r\cdot\frac{\partial{Loss }}{\partial{w}}|_{w:w^{i}},r>0
其中r是学习率,\frac{\partial{Loss }}{\partial{w}}|_{w:w^{i}}是损失函数在w处的导数,针对RNN在结构上很深的特征,会产生梯度消失和梯度爆炸,其中需要了解下什么是梯度消失和梯度爆炸,梯度消失指的是,RNN在某些w^i取值上,导致梯度很小,梯度爆炸指的是,w^i在某些取值上,导致梯度特别大。如果你的学习率r不变的话,那么参数要么几乎不变,要么就是变化剧烈,到时迭代动荡很难手收敛。通过我们对RNN的网络结构的建模,我们发现他的梯度是这个样子的:
\frac{\partial{L_{t}}}{\partial{W^{h}}}=\sum_{t=0}^{T}{\sum_{k=0}^{t}{ \frac{\partial{L_{t}}}{\partial{y_t}} \frac{\partial{y_{t}}}{\partial{h_t}} (\prod_{j=k+1}^{t} \frac{\partial{h_{j}}}{\partial{h_{j-1}}} ) \frac{\partial{h_{k}}}{\partial{W^h}}}}
我们先不管这一大串公式是啥意思,大值得意思就是上面公式里面有依赖于时间t的连乘符号;修正t时刻的误差需要考虑之前的所有时间k的隐藏层对时间t的影响,当kt距离越远,对应着隐含层之间的连乘次数就越多。就是这个连乘的结构产生了梯度消失,梯度爆炸也是它导致的。具体大一大波公式有需要看的话可以看下参考中的地(我只是搬运工)。
而LSTM(长短时记忆网络),因为可以通过阀门(gate,其实就是概率,共有输出、遗忘、输入三个阀门)记忆一些长期信息,所以,相比RNN,保留了更多长期信息(相应地也就保留了更多的梯度)。隐层之间的输入输出可以表示为:
c_{j}=\sigma(W^fX_{j}+b^f)c_{j-1}+\sigma({W^iX_{j}}+b^i)\sigma(WX_{j}+b),于是连乘的项可以表示为:
\frac{\partial{c_{j}}}{\partial{c_{j-1}}}=\sigma(W^fX_{j}+b)
该值得范围在0-1之间,在参数更新的过程中,可以通过控制bais较大来控制梯度保持在1,即使通过多次的连乘操作,梯度也不会下降到消失的状态。所以,相比RNN,在LSTM上,梯度消失问题得到了一定程度的缓解。

更多内容,查看如下(百面机器学习):
[图片上传失败...(image-19a36d-1585149251504)]

https://www.zhihu.com/question/44895610/answer/616818627
https://zhuanlan.zhihu.com/p/30844905
https://blog.csdn.net/laolu1573/article/details/77470889

相关文章

网友评论

      本文标题:机器学习面试题-为啥LSTM比RNN好

      本文链接:https://www.haomeiwen.com/subject/xfbfuhtx.html