美文网首页
tf.keras.layers.AbstractRNNCell

tf.keras.layers.AbstractRNNCell

作者: 又双叒叕苟了一天 | 来源:发表于2020-12-29 21:58 被阅读0次

表示RNNCell的抽象对象

继承: Layer

tf.keras.layers.AbstractRNNCell(
    trainable=True, name=None, dtype=None, dynamic=False, **kwargs
)

这是实现具有自定义行为的RNN单元的基类.

每个 RNNCell必须具有以下属性并且实现 call , (output, next_state) = call(input, state)

示例:

class MinimalRNNCell(AbstractRNNCell):
    def __init__(self, units, **kwargs):
        self.units = units
        super(MinimalRNNCell, self).__init__(**kwargs)

    @property
    def state_size(self):
        """ RNN Cell的单元个数 """
        return self.units

    def build(self, input_shape):
        """ 构建网络 """
        # 将输入维度映射到units维度参数矩阵
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        # 循环网络参数矩阵
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        """ 计算每个时间步的输出和状态 """
        prev_output = states[0]
        h = K.dot(inputs, self.kernel)
        output = h + K.dot(prev_output, self.recurrent_kernel)
        return output, output

这个cell的定义不同于文献中的定义。在文献中,cell是指具有单个标量输出的对象。这个定义指的是这种单元构成的水平数组。

方法:

get_initial_state

get_initial_state(
    inputs=None, batch_size=None, dtype=None
)

相关文章

网友评论

      本文标题:tf.keras.layers.AbstractRNNCell

      本文链接:https://www.haomeiwen.com/subject/twgqoktx.html