Implementing an Attention Mechanism in Keras

Author: azim | Published 2018-08-28 18:18

    Definition of the attention layer (the approach follows https://github.com/philipperemy/keras-attention-mechanism):

    # Attention layer: scores every time step of a 3D (batch, time, features)
    # input, turns the scores into a softmax over the time dimension, and
    # returns the attention-weighted sum of the time steps.
    from keras import backend as K
    from keras import initializers
    from keras.engine.topology import Layer

    class AttLayer(Layer):
        def __init__(self, **kwargs):
            self.init = initializers.get('normal')
            super(AttLayer, self).__init__(**kwargs)

        def build(self, input_shape):
            # Expect input of shape (batch, time_steps, features).
            assert len(input_shape) == 3
            # One weight per feature, used to score each time step.
            self.W = self.add_weight(name='att_weight',
                                     shape=(input_shape[-1],),
                                     initializer=self.init,
                                     trainable=True)
            super(AttLayer, self).build(input_shape)  # be sure you call this somewhere!

        def call(self, x, mask=None):
            # Score each time step: e = tanh(x . W), shape (batch, time).
            eij = K.squeeze(K.tanh(K.dot(x, K.expand_dims(self.W))), axis=-1)

            # Softmax over the time dimension gives the attention weights.
            ai = K.exp(eij)
            weights = ai / K.expand_dims(K.sum(ai, axis=1), axis=-1)

            # Attention-weighted sum of the inputs over time, shape (batch, features).
            weighted_input = x * K.expand_dims(weights)
            return K.sum(weighted_input, axis=1)

        def compute_output_shape(self, input_shape):
            return (input_shape[0], input_shape[-1])
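
    To see concretely what call() computes, here is a minimal NumPy sketch of the same three steps, with made-up shapes (2 sequences, 5 time steps, 4 features chosen only for illustration):

    import numpy as np

    x = np.random.randn(2, 5, 4)    # (batch, time, features), toy values
    W = np.random.randn(4)          # one learnable weight per feature

    # 1) Score each time step: e = tanh(x . W) -> (batch, time)
    e = np.tanh(np.einsum('btf,f->bt', x, W))

    # 2) Softmax over the time dimension -> each row of a sums to 1
    a = np.exp(e)
    a = a / a.sum(axis=1, keepdims=True)

    # 3) Attention-weighted sum over time -> (batch, features)
    out = (x * a[:, :, None]).sum(axis=1)
    print(out.shape)   # (2, 4)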
    

    How to use it:

    from keras.layers import Input, Dense, LSTM, Bidirectional
    from keras.models import Model

    # embedding_layer, MAX_SEQUENCE_LENGTH and the training/validation arrays
    # are assumed to be defined beforehand (see the sketch below).
    sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
    embedded_sequences = embedding_layer(sequence_input)
    l_lstm = Bidirectional(LSTM(100, return_sequences=True))(embedded_sequences)
    l_att = AttLayer()(l_lstm)
    preds = Dense(2, activation='softmax')(l_att)
    model = Model(sequence_input, preds)
    model.compile(loss='categorical_crossentropy',
                  optimizer='rmsprop',
                  metrics=['acc'])

    print("model fitting - attention BiLSTM network")
    model.summary()
    model.fit(x_train, y_train, validation_data=(x_val, y_val),
              epochs=10, batch_size=50)
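
    The snippet above assumes that embedding_layer, MAX_SEQUENCE_LENGTH and the x_train/y_train/x_val/y_val arrays already exist. A minimal sketch that fills in those pieces with made-up sizes and random data (every name and number here is illustrative, not from the original post) could look like this:

    import numpy as np
    from keras.layers import Embedding
    from keras.utils import to_categorical

    # Illustrative hyperparameters, chosen only to make the example run.
    MAX_SEQUENCE_LENGTH = 100
    MAX_NB_WORDS = 20000
    EMBEDDING_DIM = 100

    embedding_layer = Embedding(MAX_NB_WORDS, EMBEDDING_DIM,
                                input_length=MAX_SEQUENCE_LENGTH)

    # Random toy data: padded word-index sequences and one-hot labels for 2 classes.
    x_train = np.random.randint(0, MAX_NB_WORDS, size=(1000, MAX_SEQUENCE_LENGTH))
    y_train = to_categorical(np.random.randint(0, 2, size=(1000,)), num_classes=2)
    x_val = np.random.randint(0, MAX_NB_WORDS, size=(200, MAX_SEQUENCE_LENGTH))
    y_val = to_categorical(np.random.randint(0, 2, size=(200,)), num_classes=2)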
    
