美文网首页
tf.keras.layers.AbstractRNNCell

tf.keras.layers.AbstractRNNCell

作者: 又双叒叕苟了一天 | 来源:发表于2020-12-29 21:58 被阅读0次

    表示RNNCell的抽象对象

    继承: Layer

    tf.keras.layers.AbstractRNNCell(
        trainable=True, name=None, dtype=None, dynamic=False, **kwargs
    )
    

    这是实现具有自定义行为的RNN单元的基类.

    每个 RNNCell必须具有以下属性并且实现 call , (output, next_state) = call(input, state)

    示例:

    class MinimalRNNCell(AbstractRNNCell):
        def __init__(self, units, **kwargs):
            self.units = units
            super(MinimalRNNCell, self).__init__(**kwargs)
    
        @property
        def state_size(self):
            """ RNN Cell的单元个数 """
            return self.units
    
        def build(self, input_shape):
            """ 构建网络 """
            # 将输入维度映射到units维度参数矩阵
            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                          initializer='uniform',
                                          name='kernel')
            # 循环网络参数矩阵
            self.recurrent_kernel = self.add_weight(
                shape=(self.units, self.units),
                initializer='uniform',
                name='recurrent_kernel')
            self.built = True
    
        def call(self, inputs, states):
            """ 计算每个时间步的输出和状态 """
            prev_output = states[0]
            h = K.dot(inputs, self.kernel)
            output = h + K.dot(prev_output, self.recurrent_kernel)
            return output, output
    

    这个cell的定义不同于文献中的定义。在文献中,cell是指具有单个标量输出的对象。这个定义指的是这种单元构成的水平数组。

    方法:

    get_initial_state

    get_initial_state(
        inputs=None, batch_size=None, dtype=None
    )
    

    相关文章

      网友评论

          本文标题:tf.keras.layers.AbstractRNNCell

          本文链接:https://www.haomeiwen.com/subject/twgqoktx.html