美文网首页
Tensorflow中双向RNN的实现

Tensorflow中双向RNN的实现

作者: JasonWayne | 来源:发表于2019-06-09 10:18 被阅读0次
@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
                        "keras.layers.RNN(cell))`, which is equivalent to "
                        "this API")
@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):
  """Build a bidirectional RNN from two `dynamic_rnn` passes.

  The forward cell runs over `inputs` in their given time order. The backward
  cell runs over a time-reversed copy of `inputs`. Its outputs are then
  reversed again, so both directions' outputs are aligned with the original
  time order.

  Args:
    cell_fw: RNN cell used for the forward direction.
    cell_bw: RNN cell used for the backward direction.
    inputs: Tensor (or nested structure of tensors) fed to `dynamic_rnn`.
      Time is on axis 1 and batch on axis 0, unless `time_major` is True
      (time on axis 0, batch on axis 1).
    sequence_length: Optional per-example lengths. When given, reversal uses
      `array_ops.reverse_sequence`, so padding past each example's length is
      not moved into the sequence. When None, the whole time axis is reversed
      with `array_ops.reverse`. Also forwarded to both `dynamic_rnn` calls.
    initial_state_fw: Optional initial state for the forward cell.
    initial_state_bw: Optional initial state for the backward cell.
    dtype: Forwarded unchanged to both `dynamic_rnn` calls.
    parallel_iterations: Forwarded unchanged to both `dynamic_rnn` calls.
    swap_memory: Forwarded unchanged to both `dynamic_rnn` calls.
    time_major: Selects the time/batch axes (see `inputs`). Also forwarded to
      both `dynamic_rnn` calls.
    scope: Variable scope name, defaulting to "bidirectional_rnn". The two
      directions live in its "fw" and "bw" sub-scopes.

  Returns:
    A tuple `(outputs, output_states)`:
      - `outputs` is `(output_fw, output_bw)`.
      - `output_states` is `(output_state_fw, output_state_bw)`.
    The outputs are paired in a tuple, not concatenated. Callers wanting a
    single tensor typically concatenate them along the last axis.

  Raises:
    Both cells are first checked with `rnn_cell_impl.assert_like_rnncell`.
    Presumably that helper raises for objects that are not RNN-cell-like;
    confirm against its definition.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction: run over the inputs in their original time order.
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
          initial_state=initial_state_fw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=fw_scope)

    # Backward direction: pick the axes that reversal must operate on.
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      """Reverse one tensor along `seq_axis`, honoring lengths if given."""
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_axis=seq_axis, batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        """Time-reverse a single tensor with this call's axes and lengths."""
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      reversed_output_bw, output_state_bw = dynamic_rnn(
          cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
          initial_state=initial_state_bw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=bw_scope)

  # The backward outputs are in reversed time order. Flip them back so that
  # output_bw[t] lines up with output_fw[t]. Map over the structure, mirroring
  # how the inputs were reversed: dynamic_rnn returns a nested structure when
  # the cell's output_size is nested. Passing such a structure straight to
  # _reverse would either fail or stack the pieces into one tensor and reverse
  # along the wrong axis. For a single tensor, map_structure calls
  # _map_reverse on it directly, so results are unchanged in that case.
  output_bw = nest.map_structure(_map_reverse, reversed_output_bw)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)

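重点看下面这几部分（中间部分代码以“.......”省略）。注意：该函数把正向、反向结果作为元组 (output_fw, output_bw) 返回，并不会自动拼接；如需得到单个张量，对于三维输出可再调用 tf.concat(outputs, 2) 沿最后一维拼接。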

  
  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      # Run the forward cell over the inputs in their original time order.
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
          initial_state=initial_state_fw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=fw_scope)

    # Elided below: choose time/batch axes from `time_major`, time-reverse
    # the inputs, and run the backward cell on them under the "bw" scope.
    # That produces `tmp` (backward outputs) and `output_state_bw`.
 .......
  # `tmp` holds the backward cell's outputs in reversed time order. Reverse it
  # again (per example, up to sequence_length when given) so that output_bw
  # is aligned with output_fw in the original time order.
  output_bw = _reverse(
      tmp, seq_lengths=sequence_length,
      seq_axis=time_axis, batch_axis=batch_axis)

  # Pair the forward and backward results in tuples. They are NOT
  # concatenated here; callers that need a single tensor typically
  # concatenate the outputs along the last axis themselves.
  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)

相关文章

网友评论

      本文标题:Tensorflow中双向RNN的实现

      本文链接:https://www.haomeiwen.com/subject/addfxctx.html