Bi-LSTM with Attention (PyTorch)

Author: Jarkata | Published on 2021-05-16 14:52

This post uses a Bi-LSTM with an attention mechanism to implement a simple sentence classification task.

First, import the required packages.

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.utils.data as Data

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Define the Bi-LSTM (Attention) hyperparameters.

batch_size = 3
embedding_dim = 2
n_hidden = 5  # number of hidden units in one cell
num_classes = 2  # 0 or 1

Prepare the data and construct the labels.


# 3-word sentences (sequence_length = 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.

vocab = list(set(" ".join(sentences).split()))
word2idx = {w: i for i, w in enumerate(vocab)}
vocab_size = len(word2idx)
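
As a quick illustration (added here; the exact indices come from Python's set ordering, so they differ from run to run), each word gets an integer id and a sentence becomes a list of ids:

print(word2idx)  # e.g. {'i': 0, 'love': 1, 'you': 2, ...} -- order is not deterministic
print([word2idx[w] for w in "i love you".split()])  # e.g. [0, 1, 2]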

Build the dataset.


def make_data(sentences):
    inputs = []
    for sen in sentences:
        inputs.append(np.asarray([word2idx[n] for n in sen.split()]))

    targets = []
    for out in labels:
        targets.append(out)  # class indices, as expected by nn.CrossEntropyLoss

    return torch.LongTensor(inputs), torch.LongTensor(targets)


inputs, targets = make_data(sentences)
dataset = Data.TensorDataset(inputs, targets)
loader = Data.DataLoader(dataset, batch_size, shuffle=True)
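
To sanity-check the pipeline (a small addition, not in the original code), you can pull one batch from the loader and look at its shapes:

x_batch, y_batch = next(iter(loader))
print(x_batch.shape)  # torch.Size([3, 3]) -> [batch_size, seq_len]
print(y_batch.shape)  # torch.Size([3])    -> [batch_size]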

Define the model. Note the attention_net layer here: it takes lstm_output and final_state as input, uses the correlation between lstm_output and final_state to produce the context vector and the attention weight vector, and then feeds the context vector through a fully connected layer to obtain the final classification result.
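
Written out for a single sample (notation added here for clarity; it is not in the original post): let H be lstm_output with shape [n_step, n_hidden * 2] and h be the concatenation of the final forward and backward hidden states, shape [n_hidden * 2]. Then

attn_weights = softmax(H @ h)       # shape [n_step]
context = H.T @ attn_weights        # shape [n_hidden * 2]

The two torch.bmm calls in attention_net compute exactly this, batched over the batch dimension.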


class BiLSTM_Attention(nn.Module):
    def __init__(self):
        super(BiLSTM_Attention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
        self.out = nn.Linear(n_hidden * 2, num_classes)


    def attention_net(self, lstm_output, final_state):
        # lstm_output : [batch_size, n_step, n_hidden * num_directions(=2)], F matrix
        # final_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]

        batch_size = len(lstm_output)
        # hidden = final_state.view(batch_size,-1,1)
        hidden = torch.cat((final_state[0], final_state[1]), dim=1).unsqueeze(2)
        # hidden : [batch_size, n_hidden * num_directions(=2), n_layer(=1)]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
        # attn_weights : [batch_size,n_step]
        soft_attn_weights = F.softmax(attn_weights, dim=1)

        # context: [batch_size, n_hidden * num_directions(=2)]
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)

        return context, soft_attn_weights

    def forward(self, X):
        '''
        :param X: [batch_size, seq_len]
        :return:
        '''
        input = self.embedding(X) # input : [batch_size, seq_len, embedding_dim]
        input = input.transpose(0, 1) # input : [seq_len, batch_size, embedding_dim]

        # final_hidden_state, final_cell_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        # output : [seq_len, batch_size, n_hidden * num_directions(=2)]
        output, (final_hidden_state, final_cell_state) = self.lstm(input)
        output = output.transpose(0, 1) #output : [batch_size, seq_len, n_hidden * num_directions(=2)]


        attn_output, attention = self.attention_net(output, final_hidden_state)
        # attn_output : [batch_size, num_classes], attention : [batch_size, n_step]
        return self.out(attn_output), attention

model = BiLSTM_Attention().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
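
A quick forward pass of the untrained model on one batch (a small sanity check, not in the original post) confirms the shapes it returns:

x_check, _ = next(iter(loader))
pred_check, attn_check = model(x_check.to(device))
print(pred_check.shape)  # torch.Size([3, 2]) -> [batch_size, num_classes]
print(attn_check.shape)  # torch.Size([3, 3]) -> [batch_size, n_step]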

Training.

for epoch in range(5000):
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred, attention = model(x)
        loss = criterion(pred,y)
        if (epoch + 1) % 1000 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Testing.


test_text = 'i hate me'
tests = [np.asarray([word2idx[n] for n in test_text.split()])]
test_batch = torch.LongTensor(tests).to(device)

# Predict
predict, _ = model(test_batch)
predict = predict.data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
    print(test_text,"is Bad Mean...")
else:
    print(test_text,"is Good Mean!!")
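
Strictly speaking, inference is usually run with the model in eval mode and without tracking gradients. This model has no dropout or batch norm, so the prediction is the same either way, but the idiomatic form (a sketch, not part of the original post) also lets you keep the attention weights of the test sentence:

model.eval()
with torch.no_grad():
    test_pred, test_attention = model(test_batch)  # test_attention : [1, n_step]
model.train()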

Visualize the attention matrix.

# Note: `attention` still holds the weights of the last training mini-batch, shape [batch_size, n_step] = [3, 3]
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(attention.cpu().data, cmap='viridis')
ax.set_xticklabels([''] + ['first_word', 'second_word', 'third_word'], fontdict={'fontsize': 14}, rotation=90)
ax.set_yticklabels([''] + ['sample_1', 'sample_2', 'sample_3'], fontdict={'fontsize': 14})
plt.show()
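
To look at the attention over the test sentence instead, you can plot test_attention from the sketch above (assuming you captured it there) in the same way:

fig = plt.figure(figsize=(6, 1))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(test_attention.cpu().data, cmap='viridis')
ax.set_xticklabels([''] + test_text.split(), fontdict={'fontsize': 14}, rotation=90)
ax.set_yticklabels([''] + [test_text], fontdict={'fontsize': 14})
plt.show()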
