美文网首页
Keras RNN 源码分析

Keras RNN 源码分析

作者: 走在成长的道路上 | 来源:发表于2018-05-18 12:06 被阅读0次

    在 keras 源码中, layers/recurrent.py 中看到 RNN 实现方式

    RNN 中的循环体使用 RNNCell 来进行定义的,

    在 RNN(Layer) 中的 compute_output_shape 函数可以查看到 RNN 输出维度的计算方法, 可以看出维度为 (输入维度, 输出维度) .代码如下:

        def compute_output_shape(self, input_shape):
            if isinstance(input_shape, list):
                input_shape = input_shape[0]
    
            if hasattr(self.cell.state_size, '__len__'):
                output_dim = self.cell.state_size[0]
            else:
                output_dim = self.cell.state_size
    
            if self.return_sequences:
                output_shape = (input_shape[0], input_shape[1], output_dim)
            else:
                output_shape = (input_shape[0], output_dim)
    
            if self.return_state:
                state_shape = [(input_shape[0], output_dim) for _ in self.states]
                return [output_shape] + state_shape
            else:
                return output_shape
    

    其中通过查看 LSTMCell 中的定义内容,如下:

        def __init__(self, units,
                     ....
                     **kwargs):
            super(LSTMCell, self).__init__(**kwargs)
            self.units = units
            self.activation = activations.get(activation)
            self.recurrent_activation = activations.get(recurrent_activation)
            self.use_bias = use_bias
                     ....
            self.dropout = min(1., max(0., dropout))
            self.recurrent_dropout = min(1., max(0., recurrent_dropout))
            self.implementation = implementation
            self.state_size = (self.units, self.units)
            self._dropout_mask = None
            self._recurrent_dropout_mask = None
    

    因此, 对于 LSTMCell 来说输出的 shape 即为 (input_shape[0], units, units), 在代码中可以看到 RNN 是通过 state 来管理当前 RNNLayer 使用哪个 LSTMCell 进行当前计算.

    在 RNNLayer 中存在 recurrent_kernel ,该只用来存放再传入下个 state 时使用的 kernel,

    相关文章

      网友评论

          本文标题:Keras RNN 源码分析

          本文链接:https://www.haomeiwen.com/subject/mhvrnxtx.html