@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
                        "keras.layers.RNN(cell))`, which is equivalent to "
                        "this API")
@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
    """Run a forward and a backward `dynamic_rnn` over the same inputs.

    The forward pass feeds `inputs` to `cell_fw` unchanged. The backward pass
    reverses `inputs` along the time axis, feeds them to `cell_bw`, and then
    reverses the resulting outputs again so that both output sequences are
    indexed in the same time order.

    Args:
      cell_fw: Cell used for the forward direction; checked with
        `rnn_cell_impl.assert_like_rnncell`.
      cell_bw: Cell used for the backward direction; checked with
        `rnn_cell_impl.assert_like_rnncell`.
      inputs: The RNN inputs. May be a nested structure; each element is
        reversed individually for the backward pass.
      sequence_length: Optional per-example lengths. When given, reversal uses
        `array_ops.reverse_sequence`; when `None`, the whole time axis is
        reversed with `array_ops.reverse`. Also forwarded to both
        `dynamic_rnn` calls.
      initial_state_fw: Forwarded as `initial_state` to the forward
        `dynamic_rnn` call.
      initial_state_bw: Forwarded as `initial_state` to the backward
        `dynamic_rnn` call.
      dtype: Forwarded to both `dynamic_rnn` calls.
      parallel_iterations: Forwarded to both `dynamic_rnn` calls.
      swap_memory: Forwarded to both `dynamic_rnn` calls.
      time_major: If false (default) axis 1 is treated as time and axis 0 as
        batch; if true, axis 0 is time and axis 1 is batch.
      scope: Name of the enclosing variable scope; defaults to
        "bidirectional_rnn".

    Returns:
      A tuple `(outputs, output_states)` where `outputs` is
      `(output_fw, output_bw)` and `output_states` is
      `(output_state_fw, output_state_bw)`. The two directions are returned
      side by side; they are not concatenated here.
    """
    rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
    rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

    with vs.variable_scope(scope or "bidirectional_rnn"):
        # Forward direction
        with vs.variable_scope("fw") as fw_scope:
            output_fw, output_state_fw = dynamic_rnn(
                cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
                initial_state=initial_state_fw, dtype=dtype,
                parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                time_major=time_major, scope=fw_scope)

        # Backward direction
        # Pick which axes are time/batch according to the input layout.
        if not time_major:
            time_axis = 1
            batch_axis = 0
        else:
            time_axis = 0
            batch_axis = 1

        def _reverse(input_, seq_lengths, seq_axis, batch_axis):
            # With known lengths, reverse only the first `seq_lengths` steps of
            # each example; otherwise flip the entire time axis.
            if seq_lengths is not None:
                return array_ops.reverse_sequence(
                    input=input_, seq_lengths=seq_lengths,
                    seq_axis=seq_axis, batch_axis=batch_axis)
            else:
                return array_ops.reverse(input_, axis=[seq_axis])

        with vs.variable_scope("bw") as bw_scope:

            def _map_reverse(inp):
                return _reverse(
                    inp,
                    seq_lengths=sequence_length,
                    seq_axis=time_axis,
                    batch_axis=batch_axis)

            # `inputs` may be nested, so reverse every element of the structure.
            inputs_reverse = nest.map_structure(_map_reverse, inputs)
            tmp, output_state_bw = dynamic_rnn(
                cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
                initial_state=initial_state_bw, dtype=dtype,
                parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                time_major=time_major, scope=bw_scope)

    # Undo the time reversal so the backward outputs line up with the forward
    # outputs step by step.
    # NOTE(review): the original indentation was lost in this copy; placing this
    # statement after the outer variable scope is believed to match upstream
    # TensorFlow -- confirm against the library source.
    output_bw = _reverse(
        tmp, seq_lengths=sequence_length,
        seq_axis=time_axis, batch_axis=batch_axis)

    outputs = (output_fw, output_bw)
    output_states = (output_state_fw, output_state_bw)

    return (outputs, output_states)
重点看下面这一部分:
with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
with vs.variable_scope("fw") as fw_scope:
# 计算正向的结果
output_fw, output_state_fw = dynamic_rnn(
cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
initial_state=initial_state_fw, dtype=dtype,
parallel_iterations=parallel_iterations, swap_memory=swap_memory,
time_major=time_major, scope=fw_scope)
.......
# 输入序列已在上面省略的部分被翻转并送入反向 RNN;这里把反向 RNN 的输出 tmp 再翻转回原来的时间顺序
output_bw = _reverse(
tmp, seq_lengths=sequence_length,
seq_axis=time_axis, batch_axis=batch_axis)
# 将正向、反向结果打包成元组返回(这里并没有做拼接)
outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)
return (outputs, output_states)
网友评论