PyTorch Learning Notes - Transformer (Model Implementation)

Author: 我的昵称违规了 | Published: 2019-05-03 10:37

PyTorch Learning Notes - torchtext and PyTorch, Example 4

0. About the PyTorch Seq2Seq project

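After getting through the torchtext basics, I found this tutorial, "Understanding and implementing seq2seq models with PyTorch and torchtext".
The project consists of six sub-projects: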

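  1. Sequence to sequence learning with neural networks
  2. Learning phrase representations using an RNN encoder-decoder for statistical machine translation
  3. Neural machine translation by jointly learning to align and translate
  4. Packed padded sequences, masking and inference
  5. Convolutional Seq2Seq
  6. Transformer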

6. Transformer

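OK, on to the last chapter: the Transformer. Back to this model again; there's no getting around it. As before, the tutorial gives no explanation, so all I can do is read the code.
The source needs no introduction: "Attention Is All You Need". I have reviewed the Transformer several times already, and this time is no different. I'm curious how the tutorial implements it; learning it the first time was painful enough.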

6.1 Preparing the data

This section brings in a new dataset class, TranslationDataset. In torchtext, machine translation datasets such as Multi30k are subclasses of TranslationDataset.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchtext
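# machine translation datasets (e.g. Multi30k) are subclasses of TranslationDataset
# note: this is the legacy torchtext API; newer torchtext releases moved these classes
# to torchtext.legacy and later removed them
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator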

import spacy

import random
import math
import os
import time

SEED=1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic=True

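# spaCy tokenizers; newer spaCy versions no longer accept the 'de'/'en' shortcuts, so load
# the model packages by full name (install with: python -m spacy download de_core_news_sm)
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')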
def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]
def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]
SRC=Field(tokenize=tokenize_de,
         init_token='<sos>',
         eos_token='<eos>',
         lower=True,
         batch_first=True)
TRG=Field(tokenize=tokenize_en,
         init_token='<sos>',
         eos_token='<eos>',
         lower=True,
         batch_first=True)
train_data,valid_data,test_data=Multi30k.splits(
    exts=('.de','.en'),
    fields=(SRC, TRG)
)
SRC.build_vocab(train_data,min_freq=2)
TRG.build_vocab(train_data,min_freq=2)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE=128
train_iterator, valid_iterator, test_iterator=BucketIterator.splits(
    (train_data,valid_data,test_data),
    batch_size=BATCH_SIZE,
    device=device
)
cuda

6.2 Building the model

[Figure: Transformer architecture]

6.2.1 Encoder and decoder

6.2.2 Using multiple attention mechanisms (multi-head context-attention, multi-head self-attention)

6.2.2.1 Multi-head self-attention

6.2.2.2 Multi-head context-attention

6.2.2.3 How is attention implemented?
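Scaled dot-product attention computes softmax(Q·K^T / sqrt(d_k))·V, where d_k is the per-head key dimension. Below is a minimal standalone sketch of that formula; the function name `dot_product_attention` is illustrative and not from the tutorial, and the masking follows the same `masked_fill(mask == 0, -1e10)` approach that `SelfAttention.forward` uses.

```python
import math
import torch
import torch.nn.functional as F

def dot_product_attention(q, k, v, mask=None):
    # q: [..., q_len, d_k], k: [..., k_len, d_k], v: [..., k_len, d_v]
    # mask: broadcastable to [..., q_len, k_len]; 0/False marks keys to hide
    d_k = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # [..., q_len, k_len]
    if mask is not None:
        # a very large negative score becomes ~0 weight after softmax
        scores = scores.masked_fill(mask == 0, -1e10)
    weights = F.softmax(scores, dim=-1)       # each row sums to 1 over the keys
    return torch.matmul(weights, v), weights  # [..., q_len, d_v], [..., q_len, k_len]

torch.manual_seed(0)
q = torch.randn(2, 3, 4)   # [batch, q_len, d_k]
k = torch.randn(2, 5, 4)   # [batch, k_len, d_k]
v = torch.randn(2, 5, 6)   # [batch, k_len, d_v]
mask = torch.tensor([1, 1, 1, 0, 0]).view(1, 1, 5)  # hide the last two keys
out, weights = dot_product_attention(q, k, v, mask)
print(out.shape, weights.shape)  # torch.Size([2, 3, 6]) torch.Size([2, 3, 5])
```

Dividing by sqrt(d_k) keeps the dot products from growing with the head size, which would otherwise push the softmax into regions with tiny gradients.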

6.2.2.4 How is multi-head attention implemented?
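Multi-head attention does not create eight separate modules. `SelfAttention` projects Q, K and V once to `hid_dim`, reshapes the last dimension into `n_heads` heads of size `hid_dim // n_heads`, runs attention per head, and reshapes back. The reshaping on its own can be sketched as follows (the helper names `split_heads` and `merge_heads` are hypothetical):

```python
import torch

def split_heads(x, n_heads):
    # [batch, len, hid_dim] -> [batch, n_heads, len, hid_dim // n_heads]
    bsz, seq_len, hid_dim = x.shape
    assert hid_dim % n_heads == 0
    return x.view(bsz, seq_len, n_heads, hid_dim // n_heads).permute(0, 2, 1, 3)

def merge_heads(x):
    # [batch, n_heads, len, head_dim] -> [batch, len, n_heads * head_dim]
    bsz, n_heads, seq_len, head_dim = x.shape
    # permute only rearranges strides, so .contiguous() is needed before .view()
    return x.permute(0, 2, 1, 3).contiguous().view(bsz, seq_len, n_heads * head_dim)

x = torch.randn(2, 7, 512)      # hid_dim = 512, as used later in this post
h = split_heads(x, n_heads=8)
print(h.shape)                  # torch.Size([2, 8, 7, 64])
assert torch.equal(merge_heads(h), x)
```

This is why `SelfAttention.forward` calls `.contiguous()` before its final `x.view(...)`: calling `.view()` directly on the permuted tensor would raise an error.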

6.2.3 Layer normalization
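Layer normalization normalizes each token's feature vector over the hidden dimension (not over the batch, as batch normalization does), then applies a learned scale (gamma) and shift (beta). A minimal sketch that reproduces `nn.LayerNorm` by hand, assuming the default eps=1e-5 (the function name is illustrative):

```python
import torch
import torch.nn as nn

def manual_layer_norm(x, gamma, beta, eps=1e-5):
    # statistics over the last (hidden) dimension of each position
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, as nn.LayerNorm uses
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

torch.manual_seed(0)
x = torch.randn(2, 5, 512)   # [batch size, sent len, hid dim]
ln = nn.LayerNorm(512)       # gamma = 1, beta = 0 at initialization
manual = manual_layer_norm(x, ln.weight, ln.bias)
print(torch.allclose(manual, ln(x), atol=1e-5))
```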

6.2.4.1 padding mask
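A padding mask hides `<pad>` positions so that no query attends to them. The sketch below builds the mask the same way `make_masks` does for `src` and shows how its shape [batch, 1, 1, src len] broadcasts against attention energies of shape [batch, n heads, query len, key len]. The pad index 1 is an assumption for illustration.

```python
import torch

PAD_IDX = 1  # hypothetical pad index, for illustration only

src = torch.tensor([[5, 8, 9, 4],
                    [7, 3, PAD_IDX, PAD_IDX]])          # [batch=2, src len=4]
src_mask = (src != PAD_IDX).unsqueeze(1).unsqueeze(2)  # [2, 1, 1, 4], as in make_masks
print(src_mask.shape)  # torch.Size([2, 1, 1, 4])

# the mask broadcasts over the head and query dimensions of the energies
energy = torch.zeros(2, 8, 4, 4)                       # [batch, n heads, q len, k len]
energy = energy.masked_fill(src_mask == 0, -1e10)
attn = torch.softmax(energy, dim=-1)
print(attn[1, 0, 0])  # tensor([0.5000, 0.5000, 0.0000, 0.0000])
```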

6.2.4.2 sequence mask
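The sequence (subsequent) mask stops the decoder from looking at future target tokens during training: position i may only attend to positions 0..i. `make_masks` combines it with the target padding mask using `&`. A small sketch with one padded target sentence (the pad index 1 is again an assumption):

```python
import torch

PAD_IDX = 1  # hypothetical pad index, for illustration only
trg = torch.tensor([[2, 5, 6, PAD_IDX]])   # [batch=1, trg len=4]
trg_len = trg.shape[1]

# lower-triangular matrix: row i (query position) may see columns 0..i (key positions)
trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).bool()  # [4, 4]
trg_pad_mask = (trg != PAD_IDX).unsqueeze(1).unsqueeze(3)        # [1, 1, 4, 1]
trg_mask = trg_pad_mask & trg_sub_mask                           # broadcasts to [1, 1, 4, 4]
print(trg_mask.int()[0, 0])
```

The row for the padding position comes out all False, so every score in that row becomes -1e10 and its attention weights end up uniform; this is harmless because the loss at that position is dropped through `ignore_index`.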

6.2.5 Residual connections
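Each sub-layer in `EncoderLayer` and `DecoderLayer` is wrapped as LayerNorm(x + Dropout(Sublayer(x))): a residual connection followed by layer normalization, the "post-LN" arrangement of the original paper. The hypothetical wrapper class below only makes that pattern explicit; the tutorial writes it inline.

```python
import torch
import torch.nn as nn

class PostLNSublayer(nn.Module):
    """Hypothetical helper: LayerNorm(x + Dropout(sublayer(x)))."""
    def __init__(self, hid_dim, dropout):
        super().__init__()
        self.ln = nn.LayerNorm(hid_dim)
        self.do = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # the residual path adds the input back, so x (and its gradient) bypasses the sub-layer
        return self.ln(x + self.do(sublayer(x)))

block = PostLNSublayer(hid_dim=16, dropout=0.1).eval()  # eval() disables dropout
x = torch.randn(3, 5, 16)
ff = nn.Linear(16, 16)
y = block(x, ff)
print(y.shape)  # torch.Size([3, 5, 16])

def zero_sublayer(t):
    return torch.zeros_like(t)

# if the sub-layer outputs zeros, only the normalized input survives
assert torch.allclose(block(x, zero_sublayer), block.ln(x))
```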

6.2.6 Positional encoding
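The paper adds fixed sinusoidal position encodings, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). Note that the tutorial code does not do this: `Encoder` and `Decoder` use a learned `nn.Embedding(1000, hid_dim)` instead, which the paper reports gives nearly identical results, but which caps sentences at 1000 tokens. A sketch of the sinusoidal version for comparison (the function name is illustrative):

```python
import math
import torch

def sinusoidal_encoding(max_len, d_model):
    # PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(same angle)
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                         * (-math.log(10000.0) / d_model))                # [d_model / 2]
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # [max_len, d_model]

pe = sinusoidal_encoding(1000, 512)
print(pe.shape)  # torch.Size([1000, 512])
```

To try it in the model, one could register `pe` as a buffer and add `pe[:src_len]` in place of `self.pos_embedding(pos)`; this is an untested suggestion, not part of the tutorial.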

6.2.7 Position-wise Feed-Forward network
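The position-wise feed-forward network is FFN(x) = max(0, x·W1 + b1)·W2 + b2, applied to every position separately. `PositionwiseFeedforward` implements it with `nn.Conv1d(..., kernel_size=1)`, and the paper itself describes it as two convolutions with kernel size 1. The sketch below checks that a kernel-size-1 convolution gives the same result as an `nn.Linear` with copied weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hid_dim, pf_dim = 512, 2048
conv = nn.Conv1d(hid_dim, pf_dim, kernel_size=1)
linear = nn.Linear(hid_dim, pf_dim)

# copy the conv parameters into the linear layer: [pf_dim, hid_dim, 1] -> [pf_dim, hid_dim]
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 7, hid_dim)                        # [batch, sent len, hid dim]
out_conv = conv(x.permute(0, 2, 1)).permute(0, 2, 1)  # what PositionwiseFeedforward does
out_linear = linear(x)                                # same weights at every position
print(torch.allclose(out_conv, out_linear, atol=1e-4))
```

So an `nn.Linear` applied to the [batch, sent len, hid dim] tensor would do the same job without the two `permute` calls.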

6.3.1 Encoder

As usual, we start with the encoder, which consists of four parts: Encoder, EncoderLayer, SelfAttention and PositionwiseFeedforward.

class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, encoder_layer, self_attention, positionwise_feedforward, dropout, device):
        super(Encoder, self).__init__()
        
        self.input_dim=input_dim
        self.hid_dim=hid_dim
        self.n_layers=n_layers
        self.n_heads=n_heads
        self.pf_dim=pf_dim
        self.encoder_layer=encoder_layer
        self.self_attention=self_attention
        self.positionwise_feedforward=positionwise_feedforward
        self.dropout=dropout
        self.device=device
        
        self.tok_embedding=nn.Embedding(input_dim, hid_dim)
        self.pos_embedding=nn.Embedding(1000,hid_dim)
        
        self.layers=nn.ModuleList([encoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device) for _ in range(n_layers)])
        
        self.do=nn.Dropout(dropout)
        self.scale=torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, src, src_mask):
        #src = [batch size, src sent len]
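        #src_mask = [batch size, 1, 1, src sent len]
        # position indices 0..src_len-1 for every sentence: [batch size, src sent len]
        pos=torch.arange(0,src.shape[1]).unsqueeze(0).repeat(src.shape[0],1).to(self.device)
        # scale token embeddings by sqrt(hid_dim), add learned position embeddings, apply dropout
        src=self.do((self.tok_embedding(src)*self.scale)+self.pos_embedding(pos))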
        #src = [batch size, src sent len, hid dim]
        
        for layer in self.layers:
            src=layer(src, src_mask)
            
        return src
class EncoderLayer(nn.Module):
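    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device):
        super(EncoderLayer,self).__init__()
        
        # note: one LayerNorm is shared by both sub-layers here;
        # standard implementations use a separate LayerNorm for each sub-layer
        self.ln=nn.LayerNorm(hid_dim)
        self.sa=self_attention(hid_dim,n_heads,dropout,device)
        self.pf=positionwise_feedforward(hid_dim, pf_dim,dropout)
        self.do=nn.Dropout(dropout)
        
    def forward(self, src, src_mask):
        #src = [batch size, src sent len, hid dim]
        #src_mask = [batch size, 1, 1, src sent len]
        # self-attention sub-layer: residual connection, then LayerNorm (post-LN)
        src=self.ln(src+self.do(self.sa(src,src,src,src_mask)))
        # feed-forward sub-layer: residual connection, then LayerNorm
        src=self.ln(src+self.do(self.pf(src)))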
        
        return src
class SelfAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super(SelfAttention,self).__init__()
        
        self.hid_dim=hid_dim
        self.n_heads=n_heads
        
        assert hid_dim%n_heads==0
        
        self.w_q=nn.Linear(hid_dim,hid_dim)
        self.w_k=nn.Linear(hid_dim, hid_dim)
        self.w_v=nn.Linear(hid_dim, hid_dim)
        
        self.fc=nn.Linear(hid_dim,hid_dim)
        self.do=nn.Dropout(dropout)
        
        self.scale=torch.sqrt(torch.FloatTensor([hid_dim//n_heads])).to(device)
    def forward(self, query, key, value, mask=None):
        bsz=query.shape[0]
        #query = key = value [batch size, sent len, hid dim]
        
        Q=self.w_q(query)
        K=self.w_k(key)
        V=self.w_v(value)
        #Q, K, V = [batch size, sent len, hid dim]
        
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        #Q, K, V = [batch size, n heads, sent len, hid dim // n heads]
        
        # attention scores: Q*K^T / sqrt(head dim)
        energy=torch.matmul(Q,K.permute(0,1,3,2))/self.scale
        #energy = [batch size, n heads, sent len, sent len]
        
        if mask is not None:
            energy=energy.masked_fill(mask==0, -1e10)
        # softmax over the key dimension, then dropout on the attention weights
        attention=self.do(F.softmax(energy, dim=-1))
        #attention = [batch size, n heads, sent len, sent len]
        
        x=torch.matmul(attention,V)
        #x = [batch size, n heads, sent len, hid dim // n heads]
        x=x.permute(0,2,1,3).contiguous()
        #x = [batch size, sent len, n heads, hid dim // n heads]
        
        x=x.view(bsz, -1, self.n_heads*(self.hid_dim//self.n_heads))
        #x = [batch size, sent len, hid dim]
        
        x=self.fc(x)
        
        return x
        
class PositionwiseFeedforward(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super(PositionwiseFeedforward,self).__init__()
        
        self.hid_dim=hid_dim
        self.pf_dim=pf_dim
        
        self.fc_1=nn.Conv1d(hid_dim,pf_dim,1)
        self.fc_2=nn.Conv1d(pf_dim, hid_dim, 1)
        
        self.do=nn.Dropout(dropout)
        
    def forward(self,x):
        #x = [batch size, sent len, hid dim]
        x = x.permute(0, 2, 1)
        #x = [batch size, hid dim, sent len]
        
        x = self.do(F.relu(self.fc_1(x)))
        #x = [batch size, ff dim, sent len]
        
        x = self.fc_2(x)
        #x = [batch size, hid dim, sent len]
        
        x = x.permute(0, 2, 1)
        #x = [batch size, sent len, hid dim]
        
        return x
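Before the layer stack, `Encoder.forward` (and likewise `Decoder.forward`) turns token ids into vectors: the token embedding is multiplied by sqrt(hid_dim), as the paper does, and a learned position embedding indexed by 0..len-1 is added. A toy-sized sketch of that step (the vocabulary size 10 and hid_dim 8 are arbitrary):

```python
import math
import torch
import torch.nn as nn

vocab_size, hid_dim, max_len = 10, 8, 1000   # toy sizes; max_len matches the tutorial
tok_embedding = nn.Embedding(vocab_size, hid_dim)
pos_embedding = nn.Embedding(max_len, hid_dim)
scale = math.sqrt(hid_dim)

src = torch.tensor([[2, 5, 7, 3],
                    [2, 4, 3, 1]])           # [batch=2, src len=4]
# one row of position indices 0..src_len-1 per sentence, as in Encoder.forward
pos = torch.arange(0, src.shape[1]).unsqueeze(0).repeat(src.shape[0], 1)
embedded = tok_embedding(src) * scale + pos_embedding(pos)
print(embedded.shape)  # torch.Size([2, 4, 8])
```

Because the position table has only `max_len` rows, a sentence longer than 1000 tokens would raise an index error at `pos_embedding(pos)`.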

6.3.2 Decoder

The decoder consists of two parts: Decoder and DecoderLayer.

class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim,n_layers,n_heads,pf_dim,decoder_layer,self_attention,positionwise_feedforward,dropout,device):
        super(Decoder,self).__init__()
        
        self.output_dim=output_dim
        self.hid_dim=hid_dim
        self.n_layers=n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.decoder_layer = decoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        self.device = device
        
        self.tok_embedding=nn.Embedding(output_dim, hid_dim)
        self.pos_embedding=nn.Embedding(1000,hid_dim)
        
        self.layers=nn.ModuleList([decoder_layer(hid_dim,n_heads,pf_dim,self_attention,positionwise_feedforward,dropout,device) for _ in range(n_layers)])
        self.fc=nn.Linear(hid_dim, output_dim)
        self.do=nn.Dropout(dropout)
        self.scale=torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, trg, src, trg_mask, src_mask):
        #trg = [batch size, trg sent len]
        #src = [batch size, src sent len, hid dim] (encoder output)
        #trg_mask = [batch size, 1, trg sent len, trg sent len]
        #src_mask = [batch size, 1, 1, src sent len]
        pos=torch.arange(0, trg.shape[1]).unsqueeze(0).repeat(trg.shape[0], 1).to(self.device)  
        trg=self.do((self.tok_embedding(trg)*self.scale)+self.pos_embedding(pos))
        
        for layer in self.layers:
            trg=layer(trg,src,trg_mask,src_mask)
        
        return self.fc(trg)
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device):
        super().__init__()
        
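        # note: one LayerNorm is shared by all three sub-layers here;
        # standard implementations use a separate LayerNorm for each sub-layer
        self.ln = nn.LayerNorm(hid_dim)
        self.sa = self_attention(hid_dim, n_heads, dropout, device)  # masked self-attention
        self.ea = self_attention(hid_dim, n_heads, dropout, device)  # encoder-decoder (context) attention
        self.pf = positionwise_feedforward(hid_dim, pf_dim, dropout)
        self.do = nn.Dropout(dropout)
        
    def forward(self, trg, src, trg_mask, src_mask):
        
        #trg = [batch size, trg sent len, hid dim]
        #src = [batch size, src sent len, hid dim]
        #trg_mask = [batch size, 1, trg sent len, trg sent len]
        #src_mask = [batch size, 1, 1, src sent len]
        
        # masked self-attention over the target prefix
        trg = self.ln(trg + self.do(self.sa(trg, trg, trg, trg_mask)))
        # queries from the decoder, keys/values from the encoder output
        trg = self.ln(trg + self.do(self.ea(trg, src, src, src_mask)))
        # position-wise feed-forward
        trg = self.ln(trg + self.do(self.pf(trg)))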
        
        return trg
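In `DecoderLayer`, `self.ea(trg, src, src, src_mask)` is the encoder-decoder (context) attention: queries come from the decoder, keys and values from the encoder output. A shape check with random tensors (sizes are arbitrary) shows why the energy matrix is [batch, n heads, trg len, src len] and why the [batch, 1, 1, src len] source mask broadcasts over it:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, n_heads, head_dim = 2, 8, 64
trg_len, src_len = 5, 7

Q = torch.randn(bsz, n_heads, trg_len, head_dim)  # projected decoder states
K = torch.randn(bsz, n_heads, src_len, head_dim)  # projected encoder output
V = torch.randn(bsz, n_heads, src_len, head_dim)

src_mask = torch.ones(bsz, 1, 1, src_len, dtype=torch.bool)
src_mask[1, ..., 5:] = False  # second sentence: last two source tokens are padding

energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(head_dim)
print(energy.shape)  # torch.Size([2, 8, 5, 7]): one row per target position
energy = energy.masked_fill(src_mask == 0, -1e10)
attention = F.softmax(energy, dim=-1)
x = torch.matmul(attention, V)
print(x.shape)       # torch.Size([2, 8, 5, 64]): same length as the target
```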

6.3.3 Putting the model together

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device
        
    def make_masks(self, src, trg):
        
        #src = [batch size, src sent len]
        #trg = [batch size, trg sent len]
        
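        # padding mask: True for real tokens, False for <pad>
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        #src_mask = [batch size, 1, 1, src sent len]
        
        trg_pad_mask = (trg != self.pad_idx).unsqueeze(1).unsqueeze(3)
        #trg_pad_mask = [batch size, 1, trg sent len, 1]

        trg_len = trg.shape[1]
        
        # sequence mask: target position i may only attend to positions <= i
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool()
        #trg_sub_mask = [trg sent len, trg sent len]
        
        trg_mask = trg_pad_mask & trg_sub_mask
        #trg_mask = [batch size, 1, trg sent len, trg sent len]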
        
        return src_mask, trg_mask
    def forward(self, src, trg):
        #src = [batch size, src sent len]
        #trg = [batch size, trg sent len]
                
        src_mask, trg_mask = self.make_masks(src, trg)
        enc_src = self.encoder(src, src_mask)
        #enc_src = [batch size, src sent len, hid dim]
                
        out = self.decoder(trg, enc_src, trg_mask, src_mask)
        #out = [batch size, trg sent len, output dim]
        
        return out
input_dim=len(SRC.vocab)
hid_dim=512
n_layers=6
n_heads=8
pf_dim=2048
dropout=0.1
enc=Encoder(input_dim,hid_dim,n_layers,n_heads,pf_dim,EncoderLayer,SelfAttention,PositionwiseFeedforward,dropout,device)
output_dim=len(TRG.vocab)
hid_dim=512
n_layers=6
n_heads=8
pf_dim=2048
dropout=0.1
dec=Decoder(output_dim,hid_dim, n_layers, n_heads, pf_dim, DecoderLayer, SelfAttention, PositionwiseFeedforward, dropout, device)
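# SRC and TRG are built with the same special tokens, so '<pad>' has the same index in both vocabularies
pad_idx=SRC.vocab.stoi['<pad>']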
model=Seq2Seq(enc,dec,pad_idx,device).to(device)
model

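This is the printed model structure. It shows how the Encoder and Decoder are assembled; it's worth comparing it against the architecture diagram above.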

    Seq2Seq(
      (encoder): Encoder(
        (tok_embedding): Embedding(7855, 512)
        (pos_embedding): Embedding(1000, 512)
        (layers): ModuleList(
          (0): EncoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (1): EncoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (2): EncoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (3): EncoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (4): EncoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (5): EncoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
        )
        (do): Dropout(p=0.1)
      )
      (decoder): Decoder(
        (tok_embedding): Embedding(5893, 512)
        (pos_embedding): Embedding(1000, 512)
        (layers): ModuleList(
          (0): DecoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (ea): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (1): DecoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (ea): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (2): DecoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (ea): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (3): DecoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (ea): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (4): DecoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (ea): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
          (5): DecoderLayer(
            (ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
            (sa): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (ea): SelfAttention(
              (w_q): Linear(in_features=512, out_features=512, bias=True)
              (w_k): Linear(in_features=512, out_features=512, bias=True)
              (w_v): Linear(in_features=512, out_features=512, bias=True)
              (fc): Linear(in_features=512, out_features=512, bias=True)
              (do): Dropout(p=0.1)
            )
            (pf): PositionwiseFeedforward(
              (fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
              (fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
              (do): Dropout(p=0.1)
            )
            (do): Dropout(p=0.1)
          )
        )
        (fc): Linear(in_features=512, out_features=5893, bias=True)
        (do): Dropout(p=0.1)
      )
    )

6.4.1 Parameter setup

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 55,206,149 trainable parameters
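The 55,206,149 figure can be reproduced by hand from the printed structure, which also confirms that each layer holds a single LayerNorm. A pure-Python sketch (the vocabulary sizes 7855 and 5893 are read from the printout above; the helper names are illustrative):

```python
def linear_params(n_in, n_out):
    # weight + bias; a Conv1d with kernel_size=1 has the same count
    return n_in * n_out + n_out

def ln_params(d):
    # gamma + beta
    return 2 * d

hid, pf, n_layers, max_len = 512, 2048, 6, 1000
src_vocab, trg_vocab = 7855, 5893

attention = 4 * linear_params(hid, hid)                      # w_q, w_k, w_v, fc
feedforward = linear_params(hid, pf) + linear_params(pf, hid)  # fc_1, fc_2
enc_layer = ln_params(hid) + attention + feedforward          # one shared LayerNorm
dec_layer = ln_params(hid) + 2 * attention + feedforward      # sa + ea

encoder = src_vocab * hid + max_len * hid + n_layers * enc_layer
decoder = (trg_vocab * hid + max_len * hid + n_layers * dec_layer
           + linear_params(hid, trg_vocab))                   # output projection fc
print(f"{encoder + decoder:,}")  # 55,206,149
```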
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
class NoamOpt:
    "Optimizer wrapper that sets the learning rate from the Noam schedule before each step."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "lrate = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
optimizer = NoamOpt(hid_dim, 1, 2000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
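`NoamOpt.rate` implements the schedule from the paper, lrate = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5): the learning rate grows linearly for `warmup` steps and then decays with the inverse square root of the step number. The paper uses warmup = 4000; here it is 2000 with factor 1. A quick standalone sketch of the resulting values (`noam_rate` is an illustrative name):

```python
def noam_rate(step, model_size=512, factor=1, warmup=2000):
    # same expression as NoamOpt.rate(step) with hid_dim=512, factor=1, warmup=2000
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)))

for step in (1, 500, 1000, 2000, 4000, 16000):
    print(f"step {step:>5}: lr = {noam_rate(step):.2e}")
```

The peak, about 9.9e-4, is reached at step 2000. Multi30k has roughly 29,000 training pairs, so with BATCH_SIZE=128 an epoch is about 227 optimizer steps, which means the warm-up alone spans close to nine epochs.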

6.4.2 Training the model

def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        src = batch.src
        trg = batch.trg
        
        optimizer.optimizer.zero_grad()
        
        output = model(src, trg[:,:-1])
                
        #output = [batch size, trg sent len - 1, output dim]
        #trg = [batch size, trg sent len]
            
        output = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:,1:].contiguous().view(-1)
                
        #output = [batch size * (trg sent len - 1), output dim]
        #trg = [batch size * (trg sent len - 1)]
            
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg

            output = model(src, trg[:,:-1])
            
            #output = [batch size, trg sent len - 1, output dim]
            #trg = [batch size, trg sent len]
            
            output = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:,1:].contiguous().view(-1)
            
            #output = [batch size * (trg sent len - 1), output dim]
            #trg = [batch size * (trg sent len - 1)]
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
N_EPOCHS = 10
CLIP = 1
SAVE_DIR = 'models'
MODEL_SAVE_PATH = os.path.join(SAVE_DIR, 'transformer-seq2seq.pt')

best_valid_loss = float('inf')

if not os.path.isdir(f'{SAVE_DIR}'):
    os.makedirs(f'{SAVE_DIR}')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
    
    print(f'| Epoch: {epoch+1:03} | Time: {epoch_mins}m {epoch_secs}s| Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f} | Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f} |')
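`train` and `evaluate` use teacher forcing with the target shifted by one: the decoder receives `trg[:, :-1]` (everything except the last token) and is scored against `trg[:, 1:]` (everything except `<sos>`). The sketch below (token ids and pad index 1 are made up) shows the shift and checks that `ignore_index` makes padded positions drop out of the loss:

```python
import torch
import torch.nn as nn

PAD_IDX = 1  # hypothetical pad index, for illustration only
trg = torch.tensor([[2, 10, 11, 12, 3],         # <sos> w1 w2 w3 <eos>
                    [2, 13, 14, 3, PAD_IDX]])  # <sos> w1 w2 <eos> <pad>

dec_input = trg[:, :-1]  # fed to the decoder: drops the last token
target = trg[:, 1:]      # what each position must predict: drops <sos>
print(dec_input.tolist())  # [[2, 10, 11, 12], [2, 13, 14, 3]]
print(target.tolist())     # [[10, 11, 12, 3], [13, 14, 3, 1]]

# with ignore_index, the mean loss is taken over non-pad targets only
torch.manual_seed(0)
logits = torch.randn(2, 4, 20)  # [batch, trg len - 1, vocab size]
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss = criterion(logits.reshape(-1, 20), target.reshape(-1))

keep = target.reshape(-1) != PAD_IDX
manual = nn.CrossEntropyLoss()(logits.reshape(-1, 20)[keep], target.reshape(-1)[keep])
print(torch.allclose(loss, manual))  # True
```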

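Here are the results of a rough run. I really wonder what hardware the tutorial was run on, since its epochs take only a little over 40 seconds, compared with about 1m48s here…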

| Epoch: 001 | Time: 1m 44s| Train Loss: 5.924 | Train PPL: 373.732 | Val. Loss: 4.119 | Val. PPL: 61.478 |
| Epoch: 002 | Time: 1m 48s| Train Loss: 3.778 | Train PPL: 43.709 | Val. Loss: 3.177 | Val. PPL: 23.976 |
| Epoch: 003 | Time: 1m 48s| Train Loss: 3.133 | Train PPL: 22.939 | Val. Loss: 2.812 | Val. PPL: 16.645 |
| Epoch: 004 | Time: 1m 48s| Train Loss: 2.763 | Train PPL: 15.846 | Val. Loss: 2.611 | Val. PPL: 13.615 |
| Epoch: 005 | Time: 1m 47s| Train Loss: 2.500 | Train PPL: 12.183 | Val. Loss: 2.421 | Val. PPL: 11.260 |
| Epoch: 006 | Time: 1m 48s| Train Loss: 2.310 | Train PPL: 10.073 | Val. Loss: 2.334 | Val. PPL: 10.318 |


Article link: https://www.haomeiwen.com/subject/gjyanqtx.html