本文是自己对于lstm的理解的总结,但是最好的文章帮助理解lstm一定是这篇Understanding LSTM Networks
lstm理论
我们平时在做决策的时候,会基于当前的情况和自己平时的经验来做出判断,这就会涉及到一个很重要的概念:经验,对于计算机来说,经验要怎么获取到呢?计算机有CPU进行运算,有内存进行存储,很自然的我们的想法就是将之前的输入存储到内存中,作为我们今后决策的一个经验
我们想下FC(fully connection)网络,其输出结果就只跟当前输入有关,没有用到“经验”,那对于一些需要用到之前输入的场景就会显的力不从心,
因此我们一个很自然的想法就是将之前输入记住的,但是记住多少呢?我们人的记忆力有限,对于一些久远的、不重要的记忆自然就会忘了,
而对于我们决策总是起作用的事情就会记得很清楚,因此我们希望我们的模型也能有这个特性,记住重要的经验,遗忘不重要的
基于上面“经验”的假说,早期就有了RNN(recurrent neural network)了,其主要思想就是将单一的FC重复多份,串联起来,这样当前的输出就不止由当前输入决定,还会受前面输入的影响,其示意图如下:
来自Understanding LSTM Networks但是很快人们就发现这种模型会有两个主要的问题:vanishing gradient 和 exploding gradient,梯度消失和梯度爆炸,一个直观的解释是:看上图x(t)会由x(0)乘以每个子单元之间的一个关系得到,我们用Whh表示,其特征就是 x(t) = x(0) * Whh^t,当我们Whh小于1和大于1的时候,就会分别出现梯度消失和梯度爆炸了。
于是后来就发明了LSTM,其思想就是我们之前说的“经验”,并在此基础上加上了“遗忘”,看下结构图:
图片来自:零基础入门深度学习 (6) - 长短时记忆网络 (LSTM)
我们有3个开关,分别控制着我们的经验选择什么进行遗忘,然后当前的输入我们选择什么进行记忆,以及根据经验我们怎么做出判断,更具体的讲解可以阅读零基础入门深度学习 (6) - 长短时记忆网络 (LSTM)
理论讲了这么多,我们最重要的还是要用,下面我们就结合一些具体的例子来看下怎么使用RNN来解决实际问题
具体的例子可以查看github地址:LSTM的理解
网友评论