表示RNNCell
的抽象对象
继承: Layer
tf.keras.layers.AbstractRNNCell(
trainable=True, name=None, dtype=None, dynamic=False, **kwargs
)
这是实现具有自定义行为的RNN单元的基类.
每个 RNNCell
必须具有以下属性并且实现 call
, (output, next_state) = call(input, state)
示例:
class MinimalRNNCell(AbstractRNNCell):
def __init__(self, units, **kwargs):
self.units = units
super(MinimalRNNCell, self).__init__(**kwargs)
@property
def state_size(self):
""" RNN Cell的单元个数 """
return self.units
def build(self, input_shape):
""" 构建网络 """
# 将输入维度映射到units维度参数矩阵
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
# 循环网络参数矩阵
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
""" 计算每个时间步的输出和状态 """
prev_output = states[0]
h = K.dot(inputs, self.kernel)
output = h + K.dot(prev_output, self.recurrent_kernel)
return output, output
这个cell
的定义不同于文献中的定义。在文献中,cell
是指具有单个标量输出的对象。这个定义指的是这种单元构成的水平数组。
方法:
get_initial_state
get_initial_state(
inputs=None, batch_size=None, dtype=None
)
网友评论