Reproducing the Transformer: Multi-Head Attention

Author: 温柔倾怀 | Published 2023-07-11 21:06
    import math
    import os
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    from tqdm import tqdm
    
    def read_data(file_path, num=None):
        # Each line of the file is "text\tlabel"; num (if given) keeps only the first num samples.
        with open(file_path, "r", encoding="utf-8") as f:
            all_data = f.read().split("\n")
    
        all_text = []
        all_label = []
    
        for data in all_data:
            s_data = data.split("\t")
            if len(s_data) != 2:
                continue
            text,label = s_data
            all_text.append(text)
            all_label.append(int(label))
        if num:
            return all_text[:num],all_label[:num]
        else:
            return all_text, all_label
    
    def get_word_2_index(all_data):
        # Builds a char-level vocabulary; [PAD] and [UNK] are reserved at indices 0 and 1.
        word_2_index = {"[PAD]": 0, "[UNK]": 1}

        for data in all_data:
            for w in data:
                word_2_index[w] = word_2_index.get(w, len(word_2_index))
        return word_2_index
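
    # Hedged sanity check (toy strings, not from the real data): indices are
    # handed out in first-seen order after the two special tokens, e.g.
    #   get_word_2_index(["ab", "bc"]) -> {"[PAD]": 0, "[UNK]": 1, "a": 2, "b": 3, "c": 4}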
    
    class MyDataset(Dataset):
        def __init__(self,all_text,all_label):
            self.all_text = all_text
            self.all_label = all_label
    
        def __len__(self):
            assert len(self.all_text) == len(self.all_label)
            return len(self.all_label)
    
        def __getitem__(self, index):
            text = self.all_text[index][:512]   # truncate to at most 512 chars
            label = self.all_label[index]

            # Map chars to indices; chars unseen at vocab-build time fall back to [UNK]=1.
            text_index = [word_2_index.get(i, 1) for i in text]

            return text_index, label, len(text_index)
    
    def coll_fn(batch_data):
        # Pads every sample in the batch to the batch's longest sequence with [PAD]=0.
        batch_text_idx, batch_label, batch_len = zip(*batch_data)
        batch_max_len = max(batch_len)

        batch_text_idx_new = []
        for text_idx, len_ in zip(batch_text_idx, batch_len):
            batch_text_idx_new.append(text_idx + [0] * (batch_max_len - len_))

        return torch.tensor(batch_text_idx_new), torch.tensor(batch_label), torch.tensor(batch_len)
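
    # Illustration with made-up indices: the shorter sample is right-padded
    # with [PAD]=0 up to the batch maximum, e.g.
    #   coll_fn([([5, 6], 1, 2), ([7], 0, 1)])
    #   -> (tensor([[5, 6], [7, 0]]), tensor([1, 0]), tensor([2, 1]))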
    
    class Positional(nn.Module):
        def __init__(self, max_len=512):
            super().__init__()
            # Learned positional table; max_len matches the 512-char truncation in MyDataset.
            self.embedding = nn.Embedding(max_len, embedding_dim)

        def forward(self, x):
            # x: (batch, seq) token ids; embed the positions 0..seq-1 so the layer
            # carries order information, not token identity.
            positions = torch.arange(x.shape[1], device=x.device)
            return self.embedding(positions)[None]   # (1, seq, dim), broadcasts over the batch
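
    # For reference: the original paper uses fixed sinusoidal encodings rather
    # than a learned table. A minimal sketch (assumes an even dim; not used by
    # the model below):
    def sinusoidal_encoding(seq_len, dim):
        pe = torch.zeros(seq_len, dim)
        pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(pos * div)   # even dims: sin(pos / 10000^(2i/dim))
        pe[:, 1::2] = torch.cos(pos * div)   # odd dims:  cos(pos / 10000^(2i/dim))
        return pe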
    
    # Earlier placeholder kept for reference: a single linear + GELU, which is
    # not attention at all (superseded by Multi_Head_attention below).
    class MultiHeadAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(embedding_dim, embedding_dim)
            self.gelu = nn.GELU()

        def forward(self, x):
            return self.gelu(self.linear(x))
    
    
    class Multi_Head_attention(nn.Module):
        def __init__(self):
            super().__init__()
            self.Q = nn.Linear(embedding_dim, embedding_dim)
            self.K = nn.Linear(embedding_dim, embedding_dim)
            self.V = nn.Linear(embedding_dim, embedding_dim)

        def forward(self, x):
            b, s, n = x.shape
            # Split the model dim into head_num heads: (b, s, n) -> (b, head, s, n/head)
            Q = self.Q(x).reshape(b, s, head_num, -1).transpose(1, 2)
            K = self.K(x).reshape(b, s, head_num, -1).transpose(1, 2)
            V = self.V(x).reshape(b, s, head_num, -1).transpose(1, 2)

            # Scaled dot-product attention: scale by sqrt(d_k) as in the paper
            # and softmax over the key axis (the last dim of the score matrix).
            score = torch.softmax(Q @ K.transpose(-1, -2) / math.sqrt(Q.shape[-1]), dim=-1)
            out = score @ V

            # Merge the heads back: (b, head, s, n/head) -> (b, s, n)
            out = out.transpose(1, 2).reshape(b, s, n)
            return out
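
    # Shape walk-through with the hyper-parameters set below (embedding_dim=256,
    # head_num=2), assuming a toy batch b=2 and sequence length s=7:
    #   x: (2, 7, 256) -> Q/K/V after reshape + transpose: (2, 2, 7, 128)
    #   scores = softmax(Q @ K^T / sqrt(128)): (2, 2, 7, 7)
    #   scores @ V: (2, 2, 7, 128) -> merged back to (2, 7, 256)
    # Note: the paper also applies a final output projection W_O after merging
    # the heads; it is omitted in the class above.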
    
    class Norm(nn.Module):
        def __init__(self):
            super().__init__()
            self.norm = nn.LayerNorm(embedding_dim)
    
        def forward(self, x):
            return self.norm(x)
    
    class Feed_Forward(nn.Module):
        def __init__(self):
            super().__init__()
            # Position-wise MLP; this variant also normalizes and drops out its own output.
            self.linear1 = nn.Linear(embedding_dim, feed_num)
            self.gelu = nn.GELU()
            self.linear2 = nn.Linear(feed_num, embedding_dim)
            self.norm = nn.LayerNorm(embedding_dim)
            self.dropout = nn.Dropout(0.1)

        def forward(self, x):
            linear1_out = self.linear1(x)
            gelu_out = self.gelu(linear1_out)
            linear2_out = self.linear2(gelu_out)
            norm_out = self.norm(linear2_out)
            return self.dropout(norm_out)
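
    # For comparison, the feed-forward block in "Attention Is All You Need" is
    # just linear -> activation -> linear, with a hidden size of 4x the model
    # dim; LayerNorm and dropout live in the residual branch instead. A minimal
    # sketch (hypothetical reference, not wired into the model below):
    class VanillaFFN(nn.Module):
        def __init__(self, dim, hidden):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

        def forward(self, x):
            return self.net(x)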
    
    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.multi_head_attention = Multi_Head_attention()
            self.norm1 = Norm()
            self.feed_forward = Feed_Forward()
            self.norm2 = Norm()
    
        def forward(self, x):
            # Residual pattern used here: x + norm(sublayer(x)) for each sublayer.
            att_x = self.multi_head_attention(x)
            adn_out1 = x + self.norm1(att_x)

            ff_out = self.feed_forward(adn_out1)
            adn_out2 = adn_out1 + self.norm2(ff_out)

            return adn_out2
    
    class MyTransformer(nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size,embedding_dim)
            self.positional_layer = Positional()
    
            self.encoder = nn.Sequential(*[Block() for _ in range(num_hidden_layers)])
    
            self.cls = nn.Linear(embedding_dim,num_class)
            self.loss_fun = nn.CrossEntropyLoss()
    
        def forward(self, inputs, labels=None):
            input_emb = self.embedding(inputs)
            position_emb = self.positional_layer(inputs)

            input_embeddings = input_emb + position_emb

            encoder_out = self.encoder(input_embeddings)

            # Pool the sequence into one sentence vector: max over time here,
            # with mean pooling left as the commented-out alternative.
            # sentence_f = torch.mean(encoder_out, dim=1)
            sentence_f, _ = torch.max(encoder_out, dim=1)

            pre = self.cls(sentence_f)

            if labels is not None:
                loss = self.loss_fun(pre, labels)
                return loss
            return torch.argmax(pre, dim=-1)
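
    # One known simplification in this replication: padded positions are never
    # masked, so attention can leak onto [PAD] tokens. A sketch of the usual
    # fix (hypothetical helper, not wired into Multi_Head_attention above):
    def mask_pad_scores(scores, inputs, pad_id=0):
        # scores: (b, head, s, s) raw attention logits; inputs: (b, s) token ids.
        # Sets logits that attend *to* PAD positions to -inf so softmax zeroes them.
        pad = (inputs == pad_id).unsqueeze(1).unsqueeze(2)   # (b, 1, 1, s)
        return scores.masked_fill(pad, float("-inf"))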
    
    
    if __name__ == "__main__":
        os.chdir(os.path.dirname(os.path.abspath(__file__)))
    
        train_text,train_label = read_data(os.path.join("..","data","文本分类","train.txt"))
        test_text,test_label = read_data(os.path.join("..","data","文本分类","test.txt"))
        word_2_index = get_word_2_index(train_text)
    
        vocab_size = len(word_2_index)
    
        index_2_word = list(word_2_index)   # unused below; kept for mapping indices back to chars
    
        batch_size = 200
        epoch = 10
        embedding_dim = 256
        feed_num = 256
        num_hidden_layers = 2
        head_num = 2
        lr = 0.0001
        device = "cuda" if torch.cuda.is_available() else "cpu"
        num_class = len(set(train_label))
    
        train_dataset = MyDataset(train_text,train_label)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=coll_fn)
    
        test_dataset = MyDataset(test_text, test_label)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=coll_fn)
    
        model = MyTransformer().to(device)
        opt = torch.optim.Adam(model.parameters(),lr=lr)
    
        for e in range(epoch):
            model.train()
            for batch_text_idx, batch_label, batch_len in tqdm(train_dataloader):
                batch_text_idx = batch_text_idx.to(device)
                batch_label = batch_label.to(device)
                loss = model(batch_text_idx, batch_label)

                loss.backward()
                opt.step()
                opt.zero_grad()
            # print(f"loss:{loss:.3f}")

            # Evaluate with dropout disabled and no gradient tracking.
            model.eval()
            right_num = 0
            with torch.no_grad():
                for batch_text_idx, batch_label, batch_len in tqdm(test_dataloader):
                    batch_text_idx = batch_text_idx.to(device)
                    batch_label = batch_label.to(device)
                    pre = model(batch_text_idx)

                    right_num += int(torch.sum(batch_label == pre))

            acc = right_num / len(test_dataset)
            print(f"acc:{acc:.3f}")
    

Source: https://www.haomeiwen.com/subject/dyjuudtx.html