[tf] Writing reusable layers in the style of Keras

Author: VanJordan | Published 2019-01-19 19:49
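    • Writing __call__ by hand: the first call creates the variables inside tf.variable_scope(name, reuse=None), and every later call re-enters that scope with reuse=True so the same variables are returned.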
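    import tensorflow as tf  # TensorFlow 1.x: relies on tf.variable_scope and tf.contrib


    class LSTM(object):
      """LSTM layer using dynamic_rnn.

      Variables are created on the first call and reused on every later call.
      After the first call they are exposed in the `trainable_weights` attribute.
      """

      def __init__(self, cell_size, num_layers=1, keep_prob=1., name='LSTM'):
        self.cell_size = cell_size
        self.num_layers = num_layers
        self.keep_prob = keep_prob
        self.reuse = None  # None on the first call, so variables get created
        self.trainable_weights = None
        self.name = name

      def __call__(self, x, initial_state, seq_length):
        # First call: reuse=None creates the variables under scope `self.name`.
        # Later calls: reuse=True makes tf.get_variable return the existing ones.
        with tf.variable_scope(self.name, reuse=self.reuse) as vs:
          cell = tf.contrib.rnn.MultiRNNCell([
              tf.contrib.rnn.BasicLSTMCell(
                  self.cell_size,
                  forget_bias=0.0,
                  reuse=tf.get_variable_scope().reuse)
              for _ in range(self.num_layers)  # range, not Python 2's xrange
          ])

          lstm_out, next_state = tf.nn.dynamic_rnn(
              cell, x, initial_state=initial_state, sequence_length=seq_length)

          # shape(lstm_out) = (batch_size, timesteps, cell_size)

          if self.keep_prob < 1.:
            lstm_out = tf.nn.dropout(lstm_out, keep_prob=self.keep_prob)

          if self.reuse is None:
            # Record the variables once, right after they are created.
            self.trainable_weights = vs.trainable_variables()

        self.reuse = True  # every later call reuses the same variables

        return lstm_out, next_state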
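
The reuse logic above depends on two rules of tf.get_variable inside a tf.variable_scope: with reuse=None it creates a variable and refuses to create one whose name already exists, and with reuse=True it only returns a variable that already exists. The framework-free sketch below mimics those two rules with a plain dictionary, so the "create on the first call, reuse afterwards" flag can be tried without TensorFlow. VariableStore, its get_variable method and ReusableDense are hypothetical stand-ins, not TensorFlow APIs, and the random initialisation is an arbitrary assumption.

```python
import random


class VariableStore:
    """Hypothetical stand-in for TF1's graph-level variable store."""

    def __init__(self):
        self.variables = {}  # full variable name -> value (list of rows)

    def get_variable(self, name, shape, reuse=None):
        """Mimic tf.get_variable's rules for reuse=None and reuse=True."""
        if reuse:
            # reuse=True: only an existing variable may be returned.
            if name not in self.variables:
                raise ValueError(f"Variable {name} does not exist; cannot reuse it.")
            return self.variables[name]
        # reuse=None: a new variable is created; a duplicate name is an error.
        if name in self.variables:
            raise ValueError(f"Variable {name} already exists; set reuse=True.")
        rows, cols = shape
        self.variables[name] = [
            [random.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)
        ]
        return self.variables[name]


class ReusableDense:
    """Hypothetical dense layer using the same reuse-flag pattern as LSTM."""

    def __init__(self, store, units, name="dense"):
        self.store = store
        self.units = units
        self.name = name
        self.reuse = None  # None: the first call creates the variables
        self.trainable_weights = None

    def __call__(self, x):
        # x is a batch given as a list of equal-length rows.
        in_dim = len(x[0])
        kernel = self.store.get_variable(
            self.name + "/kernel", (in_dim, self.units), reuse=self.reuse)
        if self.reuse is None:
            self.trainable_weights = [kernel]  # recorded once, at creation
        self.reuse = True  # every later call fetches the same kernel
        # Matrix product: out[i][j] = sum over k of x[i][k] * kernel[k][j]
        return [[sum(row[k] * kernel[k][j] for k in range(in_dim))
                 for j in range(self.units)]
                for row in x]


if __name__ == "__main__":
    store = VariableStore()
    layer = ReusableDense(store, units=2)
    first = layer([[1.0, 2.0, 3.0]])
    second = layer([[1.0, 2.0, 3.0]])
    print(first == second)        # True: both calls used the same kernel
    print(list(store.variables))  # ['dense/kernel']
    try:
        # A second layer with the same name starts with reuse=None again.
        ReusableDense(store, units=2)([[1.0, 2.0, 3.0]])
    except ValueError as err:
        print(err)
```

The price of this pattern is bookkeeping: the layer itself has to remember whether it has already run, and the variable names it creates must not clash with anyone else's.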
    
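    • Using Keras: a Keras layer object creates its weights on its first call and reuses them on every later call, so wrapping a layer instance needs no scope handling.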
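    from tensorflow import keras as K  # assumption: K is the Keras module; the original never imports it


    class Actionselect(object):
      """Reusable action-selection head built on a single Keras Dense layer."""

      def __init__(self, action_class, **kwargs):
        # The Dense layer owns its kernel and bias: it creates them on its
        # first call and reuses them on every call after that.
        # Note: **kwargs is accepted but currently unused.
        self.multiclass_dense_layer = K.layers.Dense(action_class)

      def __call__(self, input_data):
        # Linear output (no activation): one score per action class.
        return self.multiclass_dense_layer(input_data)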
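
Keras layers get the same effect without any scope bookkeeping: a layer's __call__ checks its built flag, runs build() with the input shape on the first call to create the weights, and then runs call(). The sketch below reproduces that split in plain Python so the mechanism is visible. MiniLayer, MiniDense and ActionSelectSketch are hypothetical names; the real keras.layers.Layer does considerably more (dtype handling, masking, name scopes), and the random initialisation here is an arbitrary assumption.

```python
import random


class MiniLayer:
    """Hypothetical, stripped-down imitation of keras.layers.Layer."""

    def __init__(self):
        self.built = False
        self.trainable_weights = []

    def add_weight(self, shape):
        # Create one weight matrix and register it on the layer.
        rows, cols = shape
        w = [[random.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]
        self.trainable_weights.append(w)
        return w

    def build(self, input_shape):
        self.built = True

    def call(self, inputs):
        raise NotImplementedError

    def __call__(self, inputs):
        # Weights are created lazily, on the first call only,
        # once the input shape (batch, features) is known.
        if not self.built:
            self.build((len(inputs), len(inputs[0])))
            self.built = True
        return self.call(inputs)


class MiniDense(MiniLayer):
    """Hypothetical Dense: kernel of shape (input_dim, units) plus a bias."""

    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.kernel = self.add_weight((input_dim, self.units))
        self.bias = self.add_weight((1, self.units))
        super().build(input_shape)

    def call(self, inputs):
        # out[i][j] = sum over k of inputs[i][k] * kernel[k][j] + bias[j]
        return [[sum(row[k] * self.kernel[k][j] for k in range(len(row)))
                 + self.bias[0][j]
                 for j in range(self.units)]
                for row in inputs]


class ActionSelectSketch(object):
    """Same shape as Actionselect, with MiniDense in place of K.layers.Dense."""

    def __init__(self, action_class):
        self.multiclass_dense_layer = MiniDense(action_class)

    def __call__(self, input_data):
        return self.multiclass_dense_layer(input_data)


if __name__ == "__main__":
    head = ActionSelectSketch(action_class=4)
    batch = [[0.5, -1.0, 2.0], [1.0, 0.0, 0.0]]
    logits_a = head(batch)
    logits_b = head(batch)
    print(len(logits_a), len(logits_a[0]))  # 2 4
    print(logits_a == logits_b)             # True
    print(len(head.multiclass_dense_layer.trainable_weights))  # 2 (kernel, bias)
```

This is why Actionselect never touches variable scopes: the Dense instance it stores already guarantees exactly one set of weights per instance, however many times the wrapper is called.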
    

Article link: https://www.haomeiwen.com/subject/pmwzdqtx.html