- This post is a bit long, so it is split into four parts
- background
- step-by-step
- show me the code
- deep thinking
code
https://github.com/satyrswang/blog-jianshu/blob/master/LSTM.lua
background
- what?
How is an RNN different from a feedforward network?
It’s the easiest to implement an RNN just as a feedforward network with some parts of the input feeding into the middle of the stack, and a bunch of outputs coming out from there as well. There is no magic internal state kept in the network. It’s provided as a part of the input!
So it just takes the hidden layer's output and feeds it back in as part of the next step's input. =_=
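A minimal sketch of that idea in plain Lua (the weights are made up, not from any real model): the hidden state is simply recomputed from the current input and the previous state, then handed to the next step as part of its input.
local w_x, w_h, b = 0.5, 0.8, 0.1            -- hypothetical parameters
local h = 0.0                                 -- initial hidden state
for t, x in ipairs({1.0, -0.5, 0.3}) do       -- a toy input sequence
  h = math.tanh(w_x * x + w_h * h + b)        -- h(t) depends on x(t) and h(t-1)
  print(t, h)
end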
But is there theory to back this up?
- Problem
- A video has many frames, each following the previous one after only a short interval; can we predict the next frame from the previous one?
It depends. - Cloze test with a short gap: the clouds are in the ___ (the needed context sits right next to the blank, so a plain RNN copes).
I grew up in France… I speak fluent French.
When the gap grows large, France and French end up far apart, and the plain RNN stops working. - Why does a large gap break it? The theoretical arguments are in:
Bengio, et al. (1994)
Hochreiter (1991), in German
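The rough intuition behind those results, sketched with made-up numbers: backpropagating through a vanilla RNN multiplies the gradient by a factor like w * tanh'(a) at every time step, so over a long gap the gradient shrinks (or blows up) exponentially.
local w, a = 0.9, 0.5                          -- hypothetical recurrent weight and preactivation
local factor = w * (1 - math.tanh(a) ^ 2)      -- d tanh(a)/da = 1 - tanh(a)^2
local grad = 1.0
for t = 1, 50 do grad = grad * factor end      -- 50 steps back in time
print(factor, grad)                            -- ~0.71 per step, ~3e-8 after 50 steps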
LSTMs, however, are not bothered by the gap.
Long Short Term Memory networks
(figure: LSTM3-chain.png)
step-by-step
- What does the second cell in figure LSTM3-chain.png do? Don't look at the figure yet; let me walk you through it first:
Note that we read the chain horizontally; what follows is step t of the chain.
#Step 1 -- the input transform
Inputs: x(t) from the input sequence, and h(t-1), the output of the previous step in the chain
Parameters: a weight w1 for x, a weight w2 for h, plus a bias
Activation: tanh
This yields one result, denoted c_in
#Step 2 -- the three gates, each computed as follows
Inputs: x(t) from the input sequence, and h(t-1), the output of the previous step in the chain
Parameters: a weight w1 for x, a weight w2 for h, plus a bias
Activation: g (a sigmoid)
This yields three results, denoted i, f, o
Hold one question for later: Steps 1 and 2 both take x and h as input, but do they use exactly the same inputs, or only parts of the x and h vectors?
Once we have c_in, i, f, o, what happens next? How do we get this step's h?
#Step 3 -- the state update
Inputs: c_in, i, f, o, c_out(t-1)
Outputs: the new h(t) and c_out(t)
c_out(t-1) is simply the output of the previous step in the chain. How are h and c_out computed?
c_out(t) = f * c_out(t-1) + i * c_in
h(t) = o * tanh(c_out(t))
- Is it really that simple?
Yes. Why does this work?
Because of the gating mechanism the cell can keep a piece of information for long periods of time during work and protect the gradient inside the cell from harmful changes during the training. Vanilla LSTMs don’t have a forget gate and add unchanged cell state during the update (it can be seen as a recurrent connection with a constant weight of 1), what is often referred to as a Constant Error Carousel (CEC). It’s called like that, because it solves a serious RNN training problem of vanishing and exploding gradients, which in turn makes it possible to learn long-term relationships.
So it is the gating mechanism, i.e. Step 2, that fixes the RNN's gradient problem. Why does it solve vanishing and exploding gradients? For the theory, read the papers cited above. A minimal plain-Lua sketch of the three steps follows, before the full Torch7 code.
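Here is that sketch: one LSTM step on scalars with hand-picked (hypothetical) weights, just to make the three formulas concrete.
local function sigmoid(z) return 1 / (1 + math.exp(-z)) end
-- hypothetical parameters: each of c_in, i, f, o has its own {w1, w2, b}
local W = {
  c_in = {0.5, 0.6, 0.0}, i = {0.4, -0.3, 0.0},
  f    = {0.2, 0.7, 0.5}, o = {0.3,  0.1, 0.0},
}
local x, prev_h, prev_c = 1.0, 0.2, -0.4                 -- x(t), h(t-1), c_out(t-1)
local function lin(p) return p[1] * x + p[2] * prev_h + p[3] end
local c_in = math.tanh(lin(W.c_in))                      -- Step 1: input transform
local i = sigmoid(lin(W.i))                              -- Step 2: the three gates
local f = sigmoid(lin(W.f))
local o = sigmoid(lin(W.o))
local c_out = f * prev_c + i * c_in                      -- Step 3: state update
local h = o * math.tanh(c_out)
print(h, c_out)                                          -- the new h(t) and c_out(t)
Note how the cell update is additive: the old state passes through scaled only by f, which is the Constant Error Carousel idea from the quote above.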
show me the code
Based on Torch7.
- snippet1: inputs
local inputs = {}
table.insert(inputs, nn.Identity()()) -- x(t)
table.insert(inputs, nn.Identity()()) -- c_out(t-1)
table.insert(inputs, nn.Identity()()) -- h(t-1)
local input = inputs[1]
local prev_c = inputs[2]
local prev_h = inputs[3]
- Think about what we need here. Answer for yourself first, then read on:
Three variables: x(t) and h(t-1), needed by Steps 1 and 2, plus c_out(t-1), also needed by Step 3.
How do we get them? With nn.Identity()() and table.insert.
The array-like objects in Lua are called tables.
nn.Identity() - passes on the input (used as a placeholder for input)
If you come from TensorFlow, nn.Identity plays the role of a placeholder.
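A tiny check (assuming a working Torch7 install): nn.Identity simply passes its input through unchanged.
require 'nn'
local id = nn.Identity()
print(id:forward(torch.Tensor({1, 2, 3})))   -- prints the same 1, 2, 3 tensor back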
- snippet2: Computing gate values
local i2h = nn.Linear(input_size, 4 * rnn_size)(input)
local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
local preactivations = nn.CAddTable()({i2h, h2h})
- What is that 4 * rnn_size about?
Before their activations, Steps 1 and 2 are all just linear transforms of x(t) and h(t-1), i.e. nn.Linear.
preactivations adds i2h and h2h element-wise and returns a single vector.
We split the result of the linear transform into 4 chunks of rnn_size values each. Why 4? Remember the three gates that give us i, f, o?
The first chunk will be used for the in gate, the second for the forget gate, the third for the out gate, and the last one as the cell input.
Easy enough. That last chunk is the cell input c_in from Step 1, which Step 3 then combines with the gates. This also answers the earlier open question: Steps 1 and 2 see the same x(t) and h(t-1), but each chunk gets its own slice of the weight matrices, so the transforms are not the same.
local pre_sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(preactivations)
local all_gates = nn.Sigmoid()(pre_sigmoid_chunk)
local in_chunk = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations)
local in_transform = nn.Tanh()(in_chunk)
local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)
local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)
- What is nn.Narrow? It selects the appropriate parts of the preactivation vector: nn.Narrow(dim, offset, length) keeps length values along dimension dim starting at offset.
- The rest is simple: the first 3 chunks feed the gates and go through nn.Sigmoid(); the remaining chunk only needs nn.Tanh(). A small check is shown below.
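The small check of how nn.Narrow slices, assuming rnn_size = 2 so a preactivation row holds 4 * rnn_size = 8 values:
require 'nn'
local t = torch.range(1, 8):reshape(1, 8)     -- stand-in for a 1 x (4 * rnn_size) preactivation
print(nn.Narrow(2, 1, 6):forward(t))          -- columns 1-6: the three gate chunks
print(nn.Narrow(2, 7, 2):forward(t))          -- columns 7-8: the cell-input chunk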
- snippet3: Cell and hidden state
Now we also have the gate values i, f, o. On to Step 3.
local c_forget = nn.CMulTable()({forget_gate, prev_c})
local c_input = nn.CMulTable()({in_gate, in_transform})
local next_c = nn.CAddTable()({ c_forget, c_input})
local c_transform = nn.Tanh()(next_c)
local next_h = nn.CMulTable()({out_gate, c_transform})
This simply follows the formulas, nothing surprising, and yields next_c and next_h.
- snippet4: define module
local outputs = {}
table.insert(outputs, next_c)
table.insert(outputs, next_h)
return nn.gModule(inputs, outputs)
- snippet5: a complete example you can paste and run
require 'nn'
require 'nngraph'
LSTM = require 'LSTM.lua' -- the snippets above, saved as LSTM.lua
-- build a 3-layer LSTM stack with input size 3 and output size 3
network = {LSTM.create(3, 4), LSTM.create(4, 4), LSTM.create(4, 3)}
-- prepare an input and the initial per-layer states
local x = torch.randn(1, 3)
local previous_state = {
  {torch.zeros(1, 4), torch.zeros(1, 4)},
  {torch.zeros(1, 4), torch.zeros(1, 4)},
  {torch.zeros(1, 3), torch.zeros(1, 3)}
}
output = nil
next_state = {}
-- feed the data through the stack, layer by layer
local layer_input = {x, table.unpack(previous_state[1])}
for l = 1, #network do
  -- forward this layer; layer_output is {c(t), h(t)}
  local layer_output = network[l]:forward(layer_input)
  table.insert(next_state, layer_output)
  local layer_h = layer_output[2]
  if l < #network then
    layer_input = {layer_h, table.unpack(previous_state[l + 1])}
  else
    output = layer_h
  end
end
print(next_state)
print(output)
deep thinking
Even though this is already long, the intuition is still worth writing down. Now you can look at the figure.
what information we’re going to throw away from the cell state
what new information we’re going to store in the cell state
- What is the forget gate?
- It is just a sigmoid applied to a linear transform of x and h; where the result is 0, it means: forget c_out(t-1).
- This example is excellent:
the cell state might include the gender of the present subject, so that the correct pronouns can be used. When we see a new subject, we want to forget the gender of the old subject.
- What about i and c_in?
- Two parts. First i: it decides which values get updated, or rather how much of the new input to let in; think of it as scaling c_in by i (i = 1 means update fully). The tanh, in turn, decides what the updated values should be, namely c_in.
- Multiplying the two gives the new candidate contribution; adding it to f * c_out(t-1) gives the new state.
we’d want to add the gender of the new subject to the cell state, to replace the old one we’re forgetting.
- So what gets output?
- First we decide which parts of the updated state to output, using a sigmoid; the resulting o marks the parts we want to emit.
- Then we take the updated state c_out(t), squash it into [-1, 1] with tanh, multiply by o, and output the result. A small numeric illustration follows after the quote below.
since it just saw a subject, it might want to output information relevant to a verb, in case that’s what is coming next. For example, it might output whether the subject is singular or plural, so that we know what form a verb should be conjugated into if that’s what follows next.
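To make the gating concrete, a tiny numeric illustration with hand-picked gate values (hypothetical, not learned):
require 'torch'
local prev_c = torch.Tensor({ 0.9, -0.7})     -- old cell state (e.g. carrying the old subject's gender)
local c_in   = torch.Tensor({ 0.3,  0.5})     -- Step-1 candidate values
local f      = torch.Tensor({ 0.0,  1.0})     -- forget gate: drop unit 1, keep unit 2
local i      = torch.Tensor({ 1.0,  0.0})     -- in gate: write unit 1, ignore unit 2
local o      = torch.Tensor({ 1.0,  0.0})     -- out gate: expose only unit 1
local c_out  = torch.cmul(f, prev_c) + torch.cmul(i, c_in)
local h      = torch.cmul(o, torch.tanh(c_out))
print(c_out)   -- { 0.3, -0.7 }: unit 1 replaced, unit 2 carried over untouched
print(h)       -- { tanh(0.3), 0.0 }: only unit 1 is shown to the next step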
Additions and further material are welcome.
reference:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/