
attention-model code

Author: 此番风景 | Published 2017-02-24 21:39

Source: https://github.com/codekansas/keras-language-modeling/blob/master/attention_lstm.py

    from __future__ import absolute_import
    
    from keras import backend as K
    from keras.engine import InputSpec
    from keras.layers import LSTM, activations, Wrapper
    
    
    class AttentionLSTM(LSTM):
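        """LSTM whose hidden state at each timestep is scaled by attention
        weights computed from an external `attention_vec` tensor."""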
        def __init__(self, output_dim, attention_vec, attn_activation='tanh', single_attention_param=False, **kwargs):
            self.attention_vec = attention_vec
            self.attn_activation = activations.get(attn_activation)
            self.single_attention_param = single_attention_param
    
            super(AttentionLSTM, self).__init__(output_dim, **kwargs)
    
        def build(self, input_shape):
            super(AttentionLSTM, self).build(input_shape)
    
            if hasattr(self.attention_vec, '_keras_shape'):
                attention_dim = self.attention_vec._keras_shape[1]
            else:
                raise Exception('Layer could not be built: no information about expected input shape.')
    
            # Attention parameters: U_a/b_a project the hidden state, U_m/b_m
            # project the attention vector, and U_s/b_s produce the attention scores.
            self.U_a = self.inner_init((self.output_dim, self.output_dim),
                                       name='{}_U_a'.format(self.name))
            self.b_a = K.zeros((self.output_dim,), name='{}_b_a'.format(self.name))
    
            self.U_m = self.inner_init((attention_dim, self.output_dim),
                                       name='{}_U_m'.format(self.name))
            self.b_m = K.zeros((self.output_dim,), name='{}_b_m'.format(self.name))
    
            if self.single_attention_param:
                self.U_s = self.inner_init((self.output_dim, 1),
                                           name='{}_U_s'.format(self.name))
                self.b_s = K.zeros((1,), name='{}_b_s'.format(self.name))
            else:
                self.U_s = self.inner_init((self.output_dim, self.output_dim),
                                           name='{}_U_s'.format(self.name))
                self.b_s = K.zeros((self.output_dim,), name='{}_b_s'.format(self.name))
    
            self.trainable_weights += [self.U_a, self.U_m, self.U_s, self.b_a, self.b_m, self.b_s]
    
            if self.initial_weights is not None:
                self.set_weights(self.initial_weights)
                del self.initial_weights
    
        def step(self, x, states):
            h, [h, c] = super(AttentionLSTM, self).step(x, states)
            attention = states[4]  # projected attention vector appended in get_constants()
    
            m = self.attn_activation(K.dot(h, self.U_a) * attention + self.b_a)
            # Intuitively it makes more sense to use a sigmoid here; an exponential
            # scoring function was causing NaN problems, likely from exploding gradients.
            s = K.sigmoid(K.dot(m, self.U_s) + self.b_s)
    
            if self.single_attention_param:
                # A single scalar gate per timestep, broadcast across all hidden units.
                h = h * K.repeat_elements(s, self.output_dim, axis=1)
            else:
                h = h * s
    
            return h, [h, c]
    
        def get_constants(self, x):
            constants = super(AttentionLSTM, self).get_constants(x)
            # Precompute the projection of the attention vector once per call;
            # it reaches step() as the extra constant states[4].
            constants.append(K.dot(self.attention_vec, self.U_m) + self.b_m)
            return constants
    
    
    class AttentionLSTMWrapper(Wrapper):
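        """Alternative to subclassing: wraps a stock LSTM and applies the same
        attention gating by driving K.rnn with a custom step function."""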
        def __init__(self, layer, attention_vec, attn_activation='tanh', single_attention_param=False, **kwargs):
            assert isinstance(layer, LSTM)
            self.supports_masking = True
            self.attention_vec = attention_vec
            self.attn_activation = activations.get(attn_activation)
            self.single_attention_param = single_attention_param
            super(AttentionLSTMWrapper, self).__init__(layer, **kwargs)
    
        def build(self, input_shape):
            assert len(input_shape) >= 3
            self.input_spec = [InputSpec(shape=input_shape)]
    
            if not self.layer.built:
                self.layer.build(input_shape)
                self.layer.built = True
    
            super(AttentionLSTMWrapper, self).build()
    
            if hasattr(self.attention_vec, '_keras_shape'):
                attention_dim = self.attention_vec._keras_shape[1]
            else:
                raise Exception('Layer could not be built: no information about expected input shape.')
    
            self.U_a = self.layer.inner_init((self.layer.output_dim, self.layer.output_dim), name='{}_U_a'.format(self.name))
            self.b_a = K.zeros((self.layer.output_dim,), name='{}_b_a'.format(self.name))
    
            self.U_m = self.layer.inner_init((attention_dim, self.layer.output_dim), name='{}_U_m'.format(self.name))
            self.b_m = K.zeros((self.layer.output_dim,), name='{}_b_m'.format(self.name))
    
            if self.single_attention_param:
                self.U_s = self.layer.inner_init((self.layer.output_dim, 1), name='{}_U_s'.format(self.name))
                self.b_s = K.zeros((1,), name='{}_b_s'.format(self.name))
            else:
                self.U_s = self.layer.inner_init((self.layer.output_dim, self.layer.output_dim), name='{}_U_s'.format(self.name))
                self.b_s = K.zeros((self.layer.output_dim,), name='{}_b_s'.format(self.name))
    
            self.trainable_weights = [self.U_a, self.U_m, self.U_s, self.b_a, self.b_m, self.b_s]
    
        def get_output_shape_for(self, input_shape):
            return self.layer.get_output_shape_for(input_shape)
    
        def step(self, x, states):
            h, [h, c] = self.layer.step(x, states)
            attention = states[4]
    
            m = self.attn_activation(K.dot(h, self.U_a) * attention + self.b_a)
            s = K.sigmoid(K.dot(m, self.U_s) + self.b_s)
    
            if self.single_attention_param:
                h = h * K.repeat_elements(s, self.layer.output_dim, axis=1)
            else:
                h = h * s
    
            return h, [h, c]
    
        def get_constants(self, x):
            constants = self.layer.get_constants(x)
            constants.append(K.dot(self.attention_vec, self.U_m) + self.b_m)
            return constants
    
        def call(self, x, mask=None):
            # input shape: (nb_samples, time (padded with zeros), input_dim)
            # note that the .build() method of subclasses MUST define
            # self.input_spec with a complete input shape.
            input_shape = self.input_spec[0].shape
            if K._BACKEND == 'tensorflow':
                if not input_shape[1]:
                    raise Exception('When using TensorFlow, you should define '
                                    'explicitly the number of timesteps of '
                                    'your sequences.\n'
                                    'If your first layer is an Embedding, '
                                    'make sure to pass it an "input_length" '
                                    'argument. Otherwise, make sure '
                                    'the first layer has '
                                    'an "input_shape" or "batch_input_shape" '
                                    'argument, including the time axis. '
                                    'Found input shape at layer ' + self.name +
                                    ': ' + str(input_shape))
            if self.layer.stateful:
                initial_states = self.layer.states
            else:
                initial_states = self.layer.get_initial_states(x)
            constants = self.get_constants(x)
            preprocessed_input = self.layer.preprocess_input(x)
    
            last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                                 initial_states,
                                                 go_backwards=self.layer.go_backwards,
                                                 mask=mask,
                                                 constants=constants,
                                                 unroll=self.layer.unroll,
                                                 input_length=input_shape[1])
            if self.layer.stateful:
                self.updates = []
                for i in range(len(states)):
                    self.updates.append((self.layer.states[i], states[i]))
    
            if self.layer.return_sequences:
                return outputs
            else:
                return last_output
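
For context, here is a minimal usage sketch (not part of the original file) showing how either layer could be dropped into a Keras 1.x functional model. The task, layer names, sizes, and input lengths below are illustrative assumptions: one sequence is encoded into a vector that serves as the attention vector while a second sequence is read by the attention LSTM.

    from keras.layers import Input, Embedding, LSTM, Dense
    from keras.models import Model

    # Illustrative hyperparameters (assumptions, not taken from the original code).
    maxlen, vocab_size, embed_dim, hidden_dim = 40, 10000, 128, 64

    question = Input(shape=(maxlen,), dtype='int32')
    answer = Input(shape=(maxlen,), dtype='int32')

    embed = Embedding(vocab_size, embed_dim)
    q_emb = embed(question)
    a_emb = embed(answer)

    # Encode the first sequence into a single vector that drives the attention.
    q_enc = LSTM(hidden_dim)(q_emb)

    # Variant 1: the LSTM subclass defined above.
    a_enc = AttentionLSTM(hidden_dim, attention_vec=q_enc)(a_emb)

    # Variant 2: wrap a stock LSTM instead (same attention gating).
    # a_enc = AttentionLSTMWrapper(LSTM(hidden_dim), attention_vec=q_enc)(a_emb)

    score = Dense(1, activation='sigmoid')(a_enc)
    model = Model(input=[question, answer], output=score)
    model.compile(optimizer='adam', loss='binary_crossentropy')

In both variants the attention vector must carry a known static shape, since build() reads attention_vec._keras_shape to size U_m; with the TensorFlow backend the timestep dimension must also be fixed, as the wrapper's call() checks.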
    
