美文网首页
deep_learning 02. tf.nn.rnn_cell

deep_learning 02. tf.nn.rnn_cell

作者: adowu | 来源:发表于2019-03-04 11:09 被阅读0次

    开始的话:
    从基础做起,不断学习,坚持不懈,加油。
    一位爱生活爱技术来自火星的程序汪

    上一节讲到了最基础的BasicRNNCell,本章就简单介绍下BasicLSTMCell。如果有不对的地方还请指正,谢谢!

    话不多说,先上图:


    basicLSTMCell.png

    这张图大家肯定看到过很多次,是一个展开的LSTM Cell的内部结构。接下来还是和上一节一样,从tensorflow代码层面分析下。

    代码和上一节的BasicRNNCell 都差不多,只是调用的rnn_cell变了。

    def basic_lstm_demo():
        cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=4)
        zero_state = cell.zero_state(batch_size=2, dtype=tf.float32)
        a = tf.random_normal([2, 3, 4])
        out, state = tf.nn.dynamic_rnn(
            cell=cell,
            initial_state=zero_state,
            inputs=a
        )
    

    tensorflow中的主要逻辑代码如下:

      def call(self, inputs, state):
        """Long short-term memory cell (LSTM)."""
        sigmoid = math_ops.sigmoid
        one = constant_op.constant(1, dtype=dtypes.int32)
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
          c, h = state
        else:
          c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
    
        gate_inputs = math_ops.matmul(
            array_ops.concat([inputs, h], 1), self._kernel)
        gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(
            value=gate_inputs, num_or_size_splits=4, axis=one)
    
        forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
        # Note that using `add` and `multiply` instead of `+` and `*` gives a
        # performance improvement. So using those at the cost of readability.
        add = math_ops.add
        multiply = math_ops.multiply
        new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
                    multiply(sigmoid(i), self._activation(j)))
        new_h = multiply(self._activation(new_c), sigmoid(o))
    
        if self._state_is_tuple:
          new_state = LSTMStateTuple(new_c, new_h)
        else:
          new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state
    

    接下来就结合着这段源码简要分析下:

            结合着LSTM的图示来理解代码更清楚。
    
            #   输入的inputs [2,3,4],经过unstack则为 list([2,4]).size为3,所以输入到LSTM中的input为[2,4]
            #   初始化的 c 和 h 都是zero_state 也就是都为[2,4]的zero,这是参数state_is_tuple的情况下,
            #   如果这个参数为 False,则 c,h = [2,2]
            c, h = state
            #   初始化权重参数为:在此处就是 [4 + 4, 4 * 4] = [8, 16],为什么乘以4后面就可以知道原因
            kernel_ = [input_dims +num_units, 4 * num_units]
    
            #   concat[inputs, h] = [2, 8] kernel_ = [8, 16], bias=zero of [4 * num_units]
            #   所以gate_inputs = [2, 16]
            gate_inputs = bias_add(matmul(concat([inputs, h], axis=1), kernel_), bias)
    
            #   i 表示input_gate
            #   j 表示new_input
            #   f 表示forget_gate
            #   o 表示output_gate
            #   为了保持维度正确,所以前面要在num_units上乘以4的原因
            i, j, f, o = array_ops.split(value=gate_inputs, num_or_size_splits=4, axis=1)
    
            forget_bias = 1.0
    
            #   计算这个cell中的new_c 和 new_h
            #   forget_gate_output =  sigmoid(add(f, forget_bias_tensor))
            #   input_gate_output = multiply(sigmoid(i), tanh(j))
            #   update_c = add(multiply(c, forget_gate_output), input_gate_output)
            #   output_gate_output = multiply(tanh(new_c), sigmoid(o))
    
            new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),multiply(sigmoid(i), tanh(j)))
            new_h = multiply(tanh(new_c), sigmoid(o))
    

    最后的output输出为:包含了shape为[2,3,4]的每个时间步的输出,以及最后一个cell的输出,这个又包含了c 和 h,shape分别为[2,4]

    每一个时间步的输出:shape 为 [2,3,4]
        tf.Tensor(
            [[[ 0.29594404 -0.06257749  0.00272913  0.38393494]
              [ 0.12317018 -0.10669467  0.21305212 -0.0534559 ]
              [ 0.11735746 -0.03012969  0.08865868 -0.10764799]]
            
             [[-0.07051807  0.02736617  0.07237878 -0.19151129]
              [-0.07522646  0.00569247 -0.01109379 -0.00774325]
              [ 0.05763769 -0.00310471  0.21375947 -0.16625713]]], shape=(2, 3, 4), dtype=float32)
      
        最后一个时间步的输出,包括c 和 h shape 都为 [2,4]
        LSTMStateTuple(
                c=<tf.Tensor: id=309, shape=(2, 4), dtype=float32, numpy=
                array([[ 0.26399267, -0.09096628,  0.1642536 , -0.30149382],
                       [ 0.2447102 , -0.00411555,  0.38746575, -0.21990177]],
                      dtype=float32)>, 
                h=<tf.Tensor: id=312, shape=(2, 4), dtype=float32, numpy=
                array([[ 0.11735746, -0.03012969,  0.08865868, -0.10764799],
                        [ 0.05763769, -0.00310471,  0.21375947, -0.16625713]],
                        dtype=float32)>
                )
    

    这里额外介绍两种LSTM的变体:

    Peephole Connection:也就是让每一个门中都加入细胞状态c


    peephole_connection_lstm.png

    Coupled: forget_gate 和 input_gate 的sigmoid值是相关的,


    coupled_lstm.png

    BasicLSTMCell 是已经要deprecated的接口,更多的变体在接口tf.nn.rnn_cell.LSTMCell()中可以见到。

    我们了解了最基础的BasicLSTMCell,其他变体无非就是在计算方式上做了一些改变,学习起来就很简单了。

    更多代码请移步我的个人github,会不定期更新各种框架。
    本章代码见code
    欢迎关注

    相关文章

      网友评论

          本文标题:deep_learning 02. tf.nn.rnn_cell

          本文链接:https://www.haomeiwen.com/subject/qfybuqtx.html