
RNN、LSTM TF源码

Author: 小透明苞谷 | Published: 2019-06-20 11:28

    RNN

    class BasicRNNCell(RNNCell):
      """The most basic RNN cell.
      Args:
        num_units: int, The number of units in the RNN cell.
        activation: Nonlinearity to use.  Default: `tanh`.
        reuse: (optional) Python boolean describing whether to reuse variables
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
      """
    
      def __init__(self, num_units, activation=None, reuse=None):
        super(BasicRNNCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation or math_ops.tanh
        self._linear = None
    
      @property
      def state_size(self):
        return self._num_units
    
      @property
      def output_size(self):
        return self._num_units
    
      def call(self, inputs, state):
        """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
        if self._linear is None:
          self._linear = _Linear([inputs, state], self._num_units, True)
    
        output = self._activation(self._linear([inputs, state]))
        return output, output
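
    In equation form, one step of this cell computes

        h_t = \tanh(W\,[x_t;\, h_{t-1}] + b), \qquad \mathrm{output}_t = h_t

    where [x_t; h_{t-1}] is the concatenation performed by _Linear, whose single kernel has shape (input_size + num_units, num_units), and the True argument requests the bias b.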
      
    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
    print(cell.state_size)  # 128
    inputs = tf.placeholder(tf.float32, shape=[32, 100])  # batch_size = 32, input dim = 100
    h0 = cell.zero_state(32, tf.float32)  # all-zero initial state, shape (32, 128)
    output, h1 = cell(inputs=inputs, state=h0)
    print(output.shape)  # (32, 128)
    print(h1.shape)      # (32, 128)
    
    Here we first instantiate a BasicRNNCell with 128 units, then construct a placeholder of shape [32, 100] as inputs, representing a batch_size of 32 and an input dimension of 100. We then obtain the all-zero initial hidden state by calling zero_state(), and finally call the cell itself (which dispatches to call()) to get output and h1. Since a basic RNN cell returns its new state as its output (return output, output above), output and h1 are the same tensor.
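
    Calling the cell directly, as above, advances it by exactly one time step. To run it over a whole sequence, TensorFlow can unroll the cell for you. A minimal sketch using tf.nn.dynamic_rnn (the sequence length of 10 and the name seq_inputs are illustrative choices, not part of the original example):

    import tensorflow as tf

    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
    # With dynamic_rnn the inputs carry a time axis: (batch_size, max_time, input_dim)
    seq_inputs = tf.placeholder(tf.float32, shape=[32, 10, 100])
    initial_state = cell.zero_state(32, tf.float32)
    # dynamic_rnn calls the cell once per time step and stacks the per-step outputs
    outputs, final_state = tf.nn.dynamic_rnn(cell, seq_inputs,
                                             initial_state=initial_state)
    print(outputs.shape)      # (32, 10, 128) -- one output per time step
    print(final_state.shape)  # (32, 128)     -- hidden state after the last step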
    

    LSTM

    class BasicLSTMCell(RNNCell):
      def __init__(self, num_units, forget_bias=1.0,
                     state_is_tuple=True, activation=None, reuse=None):
          super(BasicLSTMCell, self).__init__(_reuse=reuse)
          if not state_is_tuple:
            logging.warn("%s: Using a concatenated state is slower and will soon be "
                         "deprecated.  Use state_is_tuple=True.", self)
          self._num_units = num_units
          self._forget_bias = forget_bias
          self._state_is_tuple = state_is_tuple
          self._activation = activation or math_ops.tanh
          self._linear = None
          
      @property
      def state_size(self):
          return (LSTMStateTuple(self._num_units, self._num_units)
              if self._state_is_tuple else 2 * self._num_units)
    
      @property
      def output_size(self):
          return self._num_units
        
      def call(self, inputs, state):
          """Long short-term memory cell (LSTM).
    
          Args:
            inputs: `2-D` tensor with shape `[batch_size x input_size]`.
            state: An `LSTMStateTuple` of state tensors, each shaped
              `[batch_size x self.state_size]`, if `state_is_tuple` has been set to
              `True`.  Otherwise, a `Tensor` shaped
              `[batch_size x 2 * self.state_size]`.
    
          Returns:
            A pair containing the new hidden state, and the new state (either a
              `LSTMStateTuple` or a concatenated state, depending on
              `state_is_tuple`).
          """
          sigmoid = math_ops.sigmoid
          # Parameters of gates are concatenated into one multiply for efficiency.
          if self._state_is_tuple:
              c, h = state
          else:
              c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
    
          if self._linear is None:
              self._linear = _Linear([inputs, h], 4 * self._num_units, True)
          # i = input_gate, j = new_input, f = forget_gate, o = output_gate
          i, j, f, o = array_ops.split(
              value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)
    
          new_c = (
              c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
          new_h = self._activation(new_c) * sigmoid(o)
    
          if self._state_is_tuple:
              new_state = LSTMStateTuple(new_c, new_h)
          else:
              new_state = array_ops.concat([new_c, new_h], 1)
          return new_h, new_state
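
    In equation form, with \sigma the logistic sigmoid, \odot elementwise multiplication, and i, j, f, o the four equal slices of the single 4 * num_units linear map:

        c_t = c_{t-1} \odot \sigma(f + b_{forget}) + \sigma(i) \odot \tanh(j)
        h_t = \tanh(c_t) \odot \sigma(o)

    Here b_{forget} is forget_bias (default 1.0); adding it keeps the forget gate close to open at the start of training, which reduces early forgetting and helps gradients flow.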
        
    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
    inputs = tf.placeholder(tf.float32, shape=(32, 100))
    h0 = cell.zero_state(32, tf.float32)  # LSTMStateTuple of (c, h), each of shape (32, 128)
    output, h1 = cell(inputs=inputs, state=h0)
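
    Unlike BasicRNNCell, the state returned here is an LSTMStateTuple, so the hidden state h and the cell state c can be inspected separately:

    print(output.shape)  # (32, 128) -- identical to h1.h
    print(h1.h.shape)    # (32, 128), the hidden state
    print(h1.c.shape)    # (32, 128), the cell state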
    

    Source: https://cuiqingcai.com/4925.html
