美文网首页
Keras RNN 源码分析

Keras RNN 源码分析

作者: 走在成长的道路上 | 来源:发表于2018-05-18 12:06 被阅读0次

在 keras 源码中, layers/recurrent.py 中看到 RNN 实现方式

RNN 中的循环体使用 RNNCell 来进行定义的,

在 RNN(Layer) 中的 compute_output_shape 函数可以查看到 RNN 输出维度的计算方法, 可以看出维度为 (输入维度, 输出维度) .代码如下:

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        if hasattr(self.cell.state_size, '__len__'):
            output_dim = self.cell.state_size[0]
        else:
            output_dim = self.cell.state_size

        if self.return_sequences:
            output_shape = (input_shape[0], input_shape[1], output_dim)
        else:
            output_shape = (input_shape[0], output_dim)

        if self.return_state:
            state_shape = [(input_shape[0], output_dim) for _ in self.states]
            return [output_shape] + state_shape
        else:
            return output_shape

其中通过查看 LSTMCell 中的定义内容,如下:

    def __init__(self, units,
                 ....
                 **kwargs):
        super(LSTMCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias
                 ....
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.implementation = implementation
        self.state_size = (self.units, self.units)
        self._dropout_mask = None
        self._recurrent_dropout_mask = None

因此, 对于 LSTMCell 来说输出的 shape 即为 (input_shape[0], units, units), 在代码中可以看到 RNN 是通过 state 来管理当前 RNNLayer 使用哪个 LSTMCell 进行当前计算.

在 RNNLayer 中存在 recurrent_kernel ,该只用来存放再传入下个 state 时使用的 kernel,

相关文章

网友评论

      本文标题:Keras RNN 源码分析

      本文链接:https://www.haomeiwen.com/subject/mhvrnxtx.html