LSTMs

Author: satyrs_sh | Published 2017-12-04 05:51
  • This post is a bit long; it's split into four parts
  1. background
  2. step-by-step
  3. show me the code
  4. deep thinking

code
https://github.com/satyrswang/blog-jianshu/blob/master/LSTM.lua

background

  • what?
    How is an RNN different from a feedforward network?

It’s easiest to implement an RNN just as a feedforward network with some parts of the input feeding into the middle of the stack, and a bunch of outputs coming out from there as well. There is no magic internal state kept in the network. It’s provided as a part of the input!

So it just takes the hidden layer's output and feeds it back in as part of the next step's input. =_=
But is there theory behind this?
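Before getting to the theory, here is that "hidden state as just another input" idea as a minimal Torch sketch (the sizes and module choices below are mine, purely for illustration, not from the post):

require 'nn'

local input_size, rnn_size = 3, 4
local step = nn.Sequential()
  :add(nn.JoinTable(2))                             -- {x(t), h(t-1)} -> one row vector
  :add(nn.Linear(input_size + rnn_size, rnn_size))
  :add(nn.Tanh())

local x = torch.randn(1, input_size)
local h = torch.zeros(1, rnn_size)                  -- the "state" is just another input
h = step:forward({x, h})                            -- h(t), fed back in at the next step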

  • Problem

    • A video has a huge number of frames, each coming right after the last with only a tiny gap in between, so can we predict the next frame from the previous one?
      It depends.
    • Cloze test: the clouds are in the ___
      I grew up in France… I speak fluent French.
      Once the gap grows large, with France and French that far apart, a plain RNN stops working.
    • Why does a large gap break it? For the theory, see:
      Bengio, et al. (1994)
      Hochreiter (1991), in German
  • LSTMs, however, don't get into trouble over the gap.
    Long Short Term Memory networks


[Figure: LSTM3-chain.png — the chain of repeating LSTM cells from colah's post]

step-by-step

  • What is the second cell in figure LSTM3-chain.png doing? Don't look at the figure yet; let me walk you through it:
    Note that we read the chain horizontally here, looking at step t
#Stage 1 -- called the input transform
Input: x(t) from the current input, and h(t-1), the output of the previous step in the chain
Parameters: a weight w1 for x, a weight w2 for h, plus a bias
Activation: tanh

This produces one result, which we'll call c_in

#Stage 2 -- called the three gates; each gate looks like this
Input: x(t) from the current input, and h(t-1), the output of the previous step in the chain
Parameters: a weight w1 for x, a weight w2 for h, plus a bias
Activation: sigmoid

This gives three results, written i, f, o
Hold on to a question: stages 1 and 2 both take x and h as inputs, but are those inputs exactly the same, or just parts of the x and h vectors?
And once we have c_in, i, f, o, what then? How do I get this step's h?

#Stage 3 -- called the state update
Input: c_in, i, f, o, and c_out(t-1)
Output: the new h(t) and c_out(t)

c_out(t-1) is simply the cell output of the previous step in the chain. So how are h and c_out computed?
c_out(t) = f * c_out(t-1) + i * c_in
h(t) = o * tanh(c_out(t))
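
Putting the three stages together as plain tensor operations (a sketch only: the per-gate weight names W, U, b are mine, not from the post, and sigmoid is the stage-2 gate activation):

-- one LSTM step on 1-D tensors; each of i, f, o, c_in has its own weights
local function lstm_step(x, prev_h, prev_c, W, U, b)
  local i    = torch.sigmoid(W.i * x + U.i * prev_h + b.i)   -- stage 2: in gate
  local f    = torch.sigmoid(W.f * x + U.f * prev_h + b.f)   -- stage 2: forget gate
  local o    = torch.sigmoid(W.o * x + U.o * prev_h + b.o)   -- stage 2: out gate
  local c_in = torch.tanh(W.c * x + U.c * prev_h + b.c)      -- stage 1: input transform
  local c    = torch.cmul(f, prev_c) + torch.cmul(i, c_in)   -- stage 3: c_out(t)
  local h    = torch.cmul(o, torch.tanh(c))                  -- stage 3: h(t)
  return h, c
end

The nn-based code in the "show me the code" section below builds exactly this, just with the four linear transforms packed into one big layer.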

  • Is it really that simple?
    Yes. But why does this work?

Because of the gating mechanism the cell can keep a piece of information for long periods of time during work and protect the gradient inside the cell from harmful changes during the training. Vanilla LSTMs don’t have a forget gate and add unchanged cell state during the update (it can be seen as a recurrent connection with a constant weight of 1), which is often referred to as a Constant Error Carousel (CEC). It’s called that because it solves a serious RNN training problem of vanishing and exploding gradients, which in turn makes it possible to learn long-term relationships.

So it's the gating mechanism, i.e. stage 2, that fixes the RNN's gradient problem. Why does it cure vanishing and exploding gradients? Roughly: along the cell-state path the local gradient ∂c_out(t)/∂c_out(t-1) is just f, so as long as the forget gate stays close to 1 the error is carried backwards almost unchanged; for the full theory, see the papers.

show me the code

Based on Torch7

  • snippet1: inputs
local inputs = {}
table.insert(inputs, nn.Identity()())   -- x(t)
table.insert(inputs, nn.Identity()())   -- c_out(t-1)
table.insert(inputs, nn.Identity()())   -- h(t-1)
local input = inputs[1]
local prev_c = inputs[2]
local prev_h = inputs[3]
  1. Think about it: what do we need? Once you've answered, here's the answer:
    three variables: x(t) and h(t-1), needed by stages 1 and 2, plus c_out(t-1), which stage 3 also needs

  2. How do we get them?
    With nn.Identity()() and table.insert

The array-like objects in lua are called tables.
nn.Identity() - passes on the input (used as a placeholder for input)

If you use TensorFlow, nn.Identity plays the role of a placeholder
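
A quick sanity check of what nn.Identity does (run it in a Torch REPL):

require 'nn'
local id = nn.Identity()
print(id:forward(torch.Tensor{1, 2, 3}))   -- prints 1 2 3: the input passes through untouched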

  • snippet2: Computing gate values
local i2h = nn.Linear(input_size, 4 * rnn_size)(input) 
local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)   
local preactivations = nn.CAddTable()({i2h, h2h})    
  1. What's with 4 * rnn_size?
    Before their activations, stages 1 and 2 are all just linear transforms of x(t) and h(t-1), i.e. nn.Linear.
    preactivations adds i2h and h2h element-wise and returns a single vector.
    We then split that linear result into 4 chunks of rnn_size values each. Why 4 chunks? Remember the three gates that give us i, f, o, plus the input transform.

The first will be used for in gates, second for forget gates, third for out gates and the last one as a cell input.

Easy enough. Careful with that last chunk, though: the "cell input" is not handed straight to h(t); it goes through nn.Tanh() to become c_in (in_transform below) and only reaches h(t) via the stage-3 state update. This also answers the question we held on to: all four chunks see exactly the same x(t) and h(t-1); what differs are the weights, since each chunk owns its own rows of the big linear layers. The small check below makes that concrete.
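
A small check of that claim (input_size = 3 and rnn_size = 4 are made-up example sizes): the big nn.Linear is just four rnn_size-row blocks stacked on top of each other, one per chunk, so the chunks share the input but not the parameters.

require 'nn'
local input_size, rnn_size = 3, 4
local i2h = nn.Linear(input_size, 4 * rnn_size)
print(i2h.weight:size())   -- 16x3: rows 1-4 feed the in gate, 5-8 the forget gate,
                           -- 9-12 the out gate, 13-16 the cell input chunk

Back to snippet2: the next few lines carve the preactivation vector into those chunks.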

local pre_sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(preactivations)    -- chunks 1-3: gate preactivations
local all_gates = nn.Sigmoid()(pre_sigmoid_chunk)                          -- i, f, o packed together
local in_chunk = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations)  -- chunk 4: cell input preactivation
local in_transform = nn.Tanh()(in_chunk)                                   -- c_in
local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)                       -- i
local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)        -- f
local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)       -- o
  2. What is nn.Narrow? (There's a quick check right after this list.)

select appropriate parts of the preactivation vector.

  3. The rest is simple: the first 3 chunks go to the gates and get an nn.Sigmoid() activation; the remaining chunk only needs an nn.Tanh().
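
The quick check promised above: nn.Narrow(dim, start, length) just slices a tensor along one dimension. With rnn_size = 2 (made up for the example) the preactivation row has 8 columns; the first 6 belong to the gates and the last 2 to the cell input:

require 'nn'
local rnn_size = 2
local pre = torch.range(1, 4 * rnn_size):view(1, 4 * rnn_size)    -- stand-in for preactivations
print(nn.Narrow(2, 1, 3 * rnn_size):forward(pre))                 -- columns 1-6: gate chunk
print(nn.Narrow(2, 3 * rnn_size + 1, rnn_size):forward(pre))      -- columns 7-8: cell input chunk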
  • snippet3: Cell and hidden state
    We now have the gate results i, f, o, so on to stage 3.
local c_forget = nn.CMulTable()({forget_gate, prev_c})    -- f * c_out(t-1)
local c_input = nn.CMulTable()({in_gate, in_transform})   -- i * c_in
local next_c = nn.CAddTable()({ c_forget, c_input})       -- c_out(t)
local c_transform = nn.Tanh()(next_c)                     -- tanh(c_out(t))
local next_h = nn.CMulTable()({out_gate, c_transform})    -- h(t) = o * tanh(c_out(t))

Just the stage-3 formulas, nothing more to it. That gives us next_c and next_h.

  • snippet4: define module
outputs = {}
table.insert(outputs, next_c)
table.insert(outputs, next_h)
return nn.gModule(inputs, outputs)
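
For completeness, here's a sketch of how snippets 1-4 presumably fit together as the LSTM.create(input_size, rnn_size) that snippet5 calls (the real version is LSTM.lua in the repo linked at the top; this is just the snippets above wrapped into a returnable module):

require 'nn'
require 'nngraph'

local LSTM = {}

function LSTM.create(input_size, rnn_size)
  -- snippet1: inputs
  local inputs = {}
  table.insert(inputs, nn.Identity()())   -- x(t)
  table.insert(inputs, nn.Identity()())   -- c_out(t-1)
  table.insert(inputs, nn.Identity()())   -- h(t-1)
  local input, prev_c, prev_h = inputs[1], inputs[2], inputs[3]

  -- snippet2: one big linear layer, split into the 4 chunks
  local i2h = nn.Linear(input_size, 4 * rnn_size)(input)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local preactivations = nn.CAddTable()({i2h, h2h})
  local all_gates = nn.Sigmoid()(nn.Narrow(2, 1, 3 * rnn_size)(preactivations))
  local in_transform = nn.Tanh()(nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations))
  local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)

  -- snippet3: cell and hidden state
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  -- snippet4: define the module
  return nn.gModule(inputs, {next_c, next_h})
end

return LSTM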
  • snippet5, for those who just want something to run: a toy example
require 'nn'
require 'nngraph'
LSTM = require 'LSTM'  -- the LSTM.lua built from the snippets above
-- build a 3-layer LSTM stack: input size 3, output size 3
network = {LSTM.create(3, 4), LSTM.create(4, 4), LSTM.create(4, 3)}
-- set up the input and an all-zero initial state
local x = torch.randn(1, 3)
local previous_state = {
  {torch.zeros(1, 4), torch.zeros(1,4)},
  {torch.zeros(1, 4), torch.zeros(1,4)},
  {torch.zeros(1, 3), torch.zeros(1,3)}
}
output = nil
next_state = {}

--feed the data through the layers
local layer_input = {x, table.unpack(previous_state[1])}
for l = 1, #network do
  local layer_output = network[l]:forward(layer_input)
  table.insert(next_state, layer_output)
  local layer_h = layer_output[2]
  if l < #network then
    layer_input = {layer_h, table.unpack(previous_state[l + 1])}
  else
    output = layer_h
  end
end

print(next_state)
print(output)

deep thinking

This is already long, but the "understanding" part still needs writing. Now you can finally look at the figure.

what information we’re going to throw away from the cell state
what new information we’re going to store in the cell state

  1. What is the forget gate?
    • It's just a sigmoid applied to the linear transform of x and h; where the result is 0, it means: forget c_out(t-1).
    • This example is excellent:

the cell state might include the gender of the present subject, so that the correct pronouns can be used. When we see a new subject, we want to forget the gender of the old subject.

  2. What about i and c_in?
    • Two parts. First i: i = 1 amounts to deciding which values need updating, or how much of the new input to let in; think of it as scaling c_in by a factor of i. The tanh, in turn, decides what the values to update should become, namely c_in.
    • Multiplying the two gives the new candidate contribution; adding it to f * c_out(t-1) gives us the new state.

we’d want to add the gender of the new subject to the cell state, to replace the old one we’re forgetting.

  3. So what gets output?

    • First we decide which parts of the updated state to expose, using a sigmoid; the resulting o is the portion we want to output.
    • Then we take the updated state c_out(t), squash it with tanh into [-1, 1], multiply by o, and output exactly the part we chose.

since it just saw a subject, it might want to output information relevant to a verb, in case that’s what is coming next. For example, it might output whether the subject is singular or plural, so that we know what form a verb should be conjugated into if that’s what follows next.
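
A tiny numeric check of the three gate behaviours just described (the numbers are made up): with f = 0 the old cell value is forgotten, with i = 0 nothing new is written, and with o = 0 a value stays in the cell but never shows up in h.

require 'torch'
local prev_c = torch.Tensor{ 2.0, -1.0}
local c_in   = torch.Tensor{ 0.5,  0.5}
local f      = torch.Tensor{ 0.0,  1.0}    -- forget entry 1, keep entry 2
local i      = torch.Tensor{ 1.0,  0.0}    -- write c_in only into entry 1
local o      = torch.Tensor{ 1.0,  0.0}    -- expose only entry 1 in h
local c      = torch.cmul(f, prev_c) + torch.cmul(i, c_in)   -- { 0.5, -1.0}
local h      = torch.cmul(o, torch.tanh(c))                  -- { tanh(0.5), 0.0}
print(c)
print(h)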

  4. Assorted variants
    Gers & Schmidhuber (2000)
    Cho, et al. (2014)
    Yao, et al. (2015)
    Koutnik, et al. (2014)

  5. Comparisons of the variants
    Greff, et al. (2015)
    Jozefowicz, et al. (2015)

Additional material is welcome.
reference:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
