PyTorch Learning Notes: torchtext and PyTorch Examples (4)
0. Introduction to the PyTorch Seq2Seq project
After working through the basics of torchtext, I found this tutorial series: "Understanding and implementing seq2seq models with PyTorch and torchtext".
The project consists of six sub-tutorials:
1. Sequence to Sequence Learning with Neural Networks
2. Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation
3. Neural Machine Translation by Jointly Learning to Align and Translate
4. Packed Padded Sequences, Masking and Inference
5. Convolutional Sequence to Sequence Learning
6. Transformer (Attention Is All You Need)
6. Transformer
OK, we have reached the final chapter: the Transformer. There is no avoiding this model, and as before the tutorial offers little explanation, so we mainly work through the code.
The source needs no introduction: "Attention Is All You Need". I have gone over the Transformer several times before, and it was painful every time; let's see how this tutorial implements it.
6.1 Preparing the data
A new dataset class appears here: TranslationDataset. The machine-translation datasets (such as Multi30k) are subclasses of TranslationDataset.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchtext
# machine-translation datasets are subclasses of TranslationDataset
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator
import spacy
import random
import math
import os
import time
SEED=1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic=True
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')
def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]
def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]
SRC = Field(tokenize=tokenize_de,
            init_token='<sos>',
            eos_token='<eos>',
            lower=True,
            batch_first=True)
TRG = Field(tokenize=tokenize_en,
            init_token='<sos>',
            eos_token='<eos>',
            lower=True,
            batch_first=True)
train_data, valid_data, test_data = Multi30k.splits(
    exts=('.de', '.en'),
    fields=(SRC, TRG)
)
SRC.build_vocab(train_data,min_freq=2)
TRG.build_vocab(train_data,min_freq=2)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE=128
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device
)
cuda
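Because both fields set batch_first=True, the iterators yield tensors with the batch dimension first, i.e. [batch size, sent len]. A quick sanity check (a sketch; the sentence length varies from bucket to bucket):
batch = next(iter(train_iterator))
print(batch.src.shape)   # e.g. torch.Size([128, 27])  -> [batch size, src sent len]
print(batch.trg.shape)   # e.g. torch.Size([128, 29])  -> [batch size, trg sent len]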
6.2 Building the model

6.2.1 Encoder and decoder
6.2.2 Multiple attention mechanisms (multi-head context-attention, multi-head self-attention)
6.2.2.1 Multi-head self-attention
6.2.2.2 Multi-head context-attention
6.2.2.3 How to implement attention?
6.2.2.4 How to implement multi-head attention?
6.2.3 Layer normalization
6.2.4.1 Padding mask
6.2.4.2 Sequence mask
6.2.5 Residual connections
6.2.6 Positional encoding
6.2.7 Position-wise feed-forward network
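The headings above are the building blocks assembled in the code below. The core operation behind 6.2.2.3 is scaled dot-product attention, Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V. A minimal stand-alone sketch with toy tensors (my own illustration, not part of the tutorial code):
def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V = [batch size, sent len, d_k]
    d_k = Q.shape[-1]
    energy = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        energy = energy.masked_fill(mask == 0, -1e10)  # masked positions get ~zero attention weight
    return torch.matmul(F.softmax(energy, dim=-1), V)

x = torch.randn(1, 5, 64)                               # a toy "sentence" of 5 positions
print(scaled_dot_product_attention(x, x, x).shape)      # torch.Size([1, 5, 64])
Multi-head attention (6.2.2.4) simply runs this operation in n_heads parallel subspaces of size hid_dim // n_heads, which is exactly what the SelfAttention class below does with view and permute.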
6.3.1 Encoder
As usual we start with the encoder side, which consists of four parts: Encoder, EncoderLayer, SelfAttention and PositionwiseFeedforward.
class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, encoder_layer, self_attention, positionwise_feedforward, dropout, device):
        super(Encoder, self).__init__()
        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.encoder_layer = encoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(1000, hid_dim)
        self.layers = nn.ModuleList([encoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device) for _ in range(n_layers)])
        self.do = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, src, src_mask):
        #src = [batch size, src sent len]
        #src_mask = [batch size, src sent len]
        pos = torch.arange(0, src.shape[1]).unsqueeze(0).repeat(src.shape[0], 1).to(self.device)
        src = self.do((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        #src = [batch size, src sent len, hid dim]
        for layer in self.layers:
            src = layer(src, src_mask)
        return src
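One thing to notice: instead of the fixed sinusoidal positional encoding from the paper (heading 6.2.6), this implementation uses a learned positional embedding, nn.Embedding(1000, hid_dim), so it cannot handle sentences longer than 1000 tokens. For reference, a sketch of the paper's sinusoidal version (my own addition, not used in the tutorial code) would be:
def sinusoidal_encoding(max_len, hid_dim):
    # PE(pos, 2i) = sin(pos / 10000^(2i/hid_dim)), PE(pos, 2i+1) = cos(pos / 10000^(2i/hid_dim))
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)                                   # [max_len, 1]
    div = torch.exp(torch.arange(0, hid_dim, 2, dtype=torch.float) * (-math.log(10000.0) / hid_dim))
    pe = torch.zeros(max_len, hid_dim)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe  # [max_len, hid_dim]; would be added to the scaled token embeddings in forward()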
class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device):
        super(EncoderLayer, self).__init__()
        self.ln = nn.LayerNorm(hid_dim)
        self.sa = self_attention(hid_dim, n_heads, dropout, device)
        self.pf = positionwise_feedforward(hid_dim, pf_dim, dropout)
        self.do = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        #src = [batch size, src sent len, hid dim]
        #src_mask = [batch size, src sent len]
        src = self.ln(src + self.do(self.sa(src, src, src, src_mask)))
        src = self.ln(src + self.do(self.pf(src)))
        return src
class SelfAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super(SelfAttention, self).__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        assert hid_dim % n_heads == 0
        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.do = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)

    def forward(self, query, key, value, mask=None):
        bsz = query.shape[0]
        #query = key = value = [batch size, sent len, hid dim]
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
        #Q, K, V = [batch size, sent len, hid dim]
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        #Q, K, V = [batch size, n heads, sent len, hid dim // n heads]
        # scaled dot-product: Q x K^T / sqrt(hid dim // n heads)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        #energy = [batch size, n heads, sent len, sent len]
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        # attention weights via softmax over the key dimension
        attention = self.do(F.softmax(energy, dim=-1))
        #attention = [batch size, n heads, sent len, sent len]
        x = torch.matmul(attention, V)
        #x = [batch size, n heads, sent len, hid dim // n heads]
        x = x.permute(0, 2, 1, 3).contiguous()
        #x = [batch size, sent len, n heads, hid dim // n heads]
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        #x = [batch size, sent len, hid dim]
        x = self.fc(x)
        return x
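A quick shape check makes the head splitting concrete: with hid_dim=512 and n_heads=8, each head attends in a 64-dimensional subspace, and the output keeps the input shape (a sketch using the class above):
sa = SelfAttention(hid_dim=512, n_heads=8, dropout=0.1, device=device).to(device)
x = torch.randn(2, 10, 512).to(device)      # [batch size, sent len, hid dim]
print(sa(x, x, x).shape)                    # torch.Size([2, 10, 512])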
class PositionwiseFeedforward(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super(PositionwiseFeedforward, self).__init__()
        self.hid_dim = hid_dim
        self.pf_dim = pf_dim
        self.fc_1 = nn.Conv1d(hid_dim, pf_dim, 1)
        self.fc_2 = nn.Conv1d(pf_dim, hid_dim, 1)
        self.do = nn.Dropout(dropout)

    def forward(self, x):
        #x = [batch size, sent len, hid dim]
        x = x.permute(0, 2, 1)
        #x = [batch size, hid dim, sent len]
        x = self.do(F.relu(self.fc_1(x)))
        #x = [batch size, ff dim, sent len]
        x = self.fc_2(x)
        #x = [batch size, hid dim, sent len]
        x = x.permute(0, 2, 1)
        #x = [batch size, sent len, hid dim]
        return x
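The position-wise feed-forward network is implemented with two Conv1d layers with kernel_size=1, which act exactly like a Linear layer applied independently at every position; hence the permutes to and from [batch size, hid dim, sent len]. A small check of that equivalence (my own sketch):
conv = nn.Conv1d(4, 6, kernel_size=1)
lin = nn.Linear(4, 6)
lin.weight.data.copy_(conv.weight.data.squeeze(-1))   # [6, 4, 1] -> [6, 4]
lin.bias.data.copy_(conv.bias.data)
x = torch.randn(2, 7, 4)                              # [batch size, sent len, dim]
print(torch.allclose(lin(x), conv(x.permute(0, 2, 1)).permute(0, 2, 1), atol=1e-6))  # True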
6.3.2 Decoder
The decoder side consists of two parts: Decoder and DecoderLayer.
class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, decoder_layer, self_attention, positionwise_feedforward, dropout, device):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.decoder_layer = decoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(1000, hid_dim)
        self.layers = nn.ModuleList([decoder_layer(hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device) for _ in range(n_layers)])
        self.fc = nn.Linear(hid_dim, output_dim)
        self.do = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, trg, src, trg_mask, src_mask):
        #trg = [batch size, trg sent len]
        #src = [batch size, src sent len, hid dim] (the encoder output)
        #trg_mask = [batch size, trg sent len]
        #src_mask = [batch size, src sent len]
        pos = torch.arange(0, trg.shape[1]).unsqueeze(0).repeat(trg.shape[0], 1).to(self.device)
        trg = self.do((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        for layer in self.layers:
            trg = layer(trg, src, trg_mask, src_mask)
        return self.fc(trg)
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, self_attention, positionwise_feedforward, dropout, device):
        super().__init__()
        self.ln = nn.LayerNorm(hid_dim)
        self.sa = self_attention(hid_dim, n_heads, dropout, device)
        self.ea = self_attention(hid_dim, n_heads, dropout, device)
        self.pf = positionwise_feedforward(hid_dim, pf_dim, dropout)
        self.do = nn.Dropout(dropout)

    def forward(self, trg, src, trg_mask, src_mask):
        #trg = [batch size, trg sent len, hid dim]
        #src = [batch size, src sent len, hid dim]
        #trg_mask = [batch size, trg sent len]
        #src_mask = [batch size, src sent len]
        trg = self.ln(trg + self.do(self.sa(trg, trg, trg, trg_mask)))
        trg = self.ln(trg + self.do(self.ea(trg, src, src, src_mask)))
        trg = self.ln(trg + self.do(self.pf(trg)))
        return trg
6.3.3 Putting the model together
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device

    def make_masks(self, src, trg):
        #src = [batch size, src sent len]
        #trg = [batch size, trg sent len]
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        trg_pad_mask = (trg != self.pad_idx).unsqueeze(1).unsqueeze(3)
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), dtype=torch.uint8, device=self.device))
        trg_mask = trg_pad_mask & trg_sub_mask
        return src_mask, trg_mask

    def forward(self, src, trg):
        #src = [batch size, src sent len]
        #trg = [batch size, trg sent len]
        src_mask, trg_mask = self.make_masks(src, trg)
        enc_src = self.encoder(src, src_mask)
        #enc_src = [batch size, src sent len, hid dim]
        out = self.decoder(trg, enc_src, trg_mask, src_mask)
        #out = [batch size, trg sent len, output dim]
        return out
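make_masks builds the two masks from 6.2.4: src_mask hides only <pad> tokens, while trg_mask additionally applies a lower-triangular "subsequent" mask so position i in the decoder can only attend to positions <= i. Both get extra singleton dimensions so they broadcast against energy = [batch size, n heads, sent len, sent len] inside SelfAttention. A toy example of the same logic (a sketch with a made-up pad index; the bool dtype behaves like the uint8 used above):
pad = 1                                                  # hypothetical pad index, just for illustration
src = torch.tensor([[4, 5, pad]])                        # last source token is padding
trg = torch.tensor([[4, 6, 7]])
src_mask = (src != pad).unsqueeze(1).unsqueeze(2)        # [1, 1, 1, 3]
trg_pad_mask = (trg != pad).unsqueeze(1).unsqueeze(3)    # [1, 1, 3, 1]
trg_sub_mask = torch.tril(torch.ones(3, 3)).bool()       # [3, 3] lower triangle
trg_mask = trg_pad_mask & trg_sub_mask                   # [1, 1, 3, 3]
print(src_mask)        # the padded position is masked out
print(trg_mask[0, 0])  # row i can only see columns <= i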
input_dim=len(SRC.vocab)
hid_dim=512
n_layers=6
n_heads=8
pf_dim=2048
dropout=0.1
enc=Encoder(input_dim,hid_dim,n_layers,n_heads,pf_dim,EncoderLayer,SelfAttention,PositionwiseFeedforward,dropout,device)
output_dim=len(TRG.vocab)
hid_dim=512
n_layers=6
n_heads=8
pf_dim=2048
dropout=0.1
dec=Decoder(output_dim,hid_dim, n_layers, n_heads, pf_dim, DecoderLayer, SelfAttention, PositionwiseFeedforward, dropout, device)
pad_idx=SRC.vocab.stoi['<pad>']
model=Seq2Seq(enc,dec,pad_idx,device).to(device)
model
This is the printed model structure, showing the stacked Encoder and Decoder layers; it is worth comparing it against the figures earlier in this series.

Seq2Seq(
(encoder): Encoder(
(tok_embedding): Embedding(7855, 512)
(pos_embedding): Embedding(1000, 512)
(layers): ModuleList(
(0): EncoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(1): EncoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(2): EncoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(3): EncoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(4): EncoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(5): EncoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
)
(do): Dropout(p=0.1)
)
(decoder): Decoder(
(tok_embedding): Embedding(5893, 512)
(pos_embedding): Embedding(1000, 512)
(layers): ModuleList(
(0): DecoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(ea): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(1): DecoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(ea): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(2): DecoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(ea): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(3): DecoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(ea): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(4): DecoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(ea): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
(5): DecoderLayer(
(ln): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
(sa): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(ea): SelfAttention(
(w_q): Linear(in_features=512, out_features=512, bias=True)
(w_k): Linear(in_features=512, out_features=512, bias=True)
(w_v): Linear(in_features=512, out_features=512, bias=True)
(fc): Linear(in_features=512, out_features=512, bias=True)
(do): Dropout(p=0.1)
)
(pf): PositionwiseFeedforward(
(fc_1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
(fc_2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
(do): Dropout(p=0.1)
)
(do): Dropout(p=0.1)
)
)
(fc): Linear(in_features=512, out_features=5893, bias=True)
(do): Dropout(p=0.1)
)
)
6.4.1 Parameter setup
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 55,206,149 trainable parameters
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Compute the learning rate at the given step"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
             min(step ** (-0.5), step * self.warmup ** (-1.5)))

optimizer = NoamOpt(hid_dim, 1, 2000,
                    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
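NoamOpt implements the learning-rate schedule from the paper: lr = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)), i.e. linear warmup for `warmup` steps followed by inverse-square-root decay, here peaking at step 2000. You can inspect the schedule before training (a sketch):
for step in [500, 1000, 2000, 4000, 16000]:
    print(step, optimizer.rate(step))   # rises until step 2000, then decays as 1/sqrt(step)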
6.4.2 Training the model
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg
        optimizer.optimizer.zero_grad()  # NoamOpt wraps the underlying Adam optimizer
        output = model(src, trg[:,:-1])
        #output = [batch size, trg sent len - 1, output dim]
        #trg = [batch size, trg sent len]
        output = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:,1:].contiguous().view(-1)
        #output = [batch size * trg sent len - 1, output dim]
        #trg = [batch size * trg sent len - 1]
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)
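Note the off-by-one between the decoder input and the loss target: the model is fed trg[:,:-1] (the target without its final token) and the loss compares against trg[:,1:] (the target without <sos>), so at every position the decoder is trained to predict the next token. On a toy sentence:
# trg            = [<sos>, y1, y2, y3, <eos>]
# decoder input  = [<sos>, y1, y2, y3]          # trg[:,:-1]
# loss target    = [  y1,  y2, y3, <eos>]       # trg[:,1:]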
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            output = model(src, trg[:,:-1])
            #output = [batch size, trg sent len - 1, output dim]
            #trg = [batch size, trg sent len]
            output = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:,1:].contiguous().view(-1)
            #output = [batch size * trg sent len - 1, output dim]
            #trg = [batch size * trg sent len - 1]
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
N_EPOCHS = 10
CLIP = 1
SAVE_DIR = 'models'
MODEL_SAVE_PATH = os.path.join(SAVE_DIR, 'transformer-seq2seq.pt')
best_valid_loss = float('inf')
if not os.path.isdir(f'{SAVE_DIR}'):
    os.makedirs(f'{SAVE_DIR}')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f'| Epoch: {epoch+1:03} | Time: {epoch_mins}m {epoch_secs}s| Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f} | Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f} |')
I ran it for a few epochs. I have no idea what hardware the tutorial was run on to get through an epoch in only 40-odd seconds; on my machine each epoch takes about 1 minute 48 seconds:
| Epoch: 001 | Time: 1m 44s| Train Loss: 5.924 | Train PPL: 373.732 | Val. Loss: 4.119 | Val. PPL: 61.478 |
| Epoch: 002 | Time: 1m 48s| Train Loss: 3.778 | Train PPL: 43.709 | Val. Loss: 3.177 | Val. PPL: 23.976 |
| Epoch: 003 | Time: 1m 48s| Train Loss: 3.133 | Train PPL: 22.939 | Val. Loss: 2.812 | Val. PPL: 16.645 |
| Epoch: 004 | Time: 1m 48s| Train Loss: 2.763 | Train PPL: 15.846 | Val. Loss: 2.611 | Val. PPL: 13.615 |
| Epoch: 005 | Time: 1m 47s| Train Loss: 2.500 | Train PPL: 12.183 | Val. Loss: 2.421 | Val. PPL: 11.260 |
| Epoch: 006 | Time: 1m 48s| Train Loss: 2.310 | Train PPL: 10.073 | Val. Loss: 2.334 | Val. PPL: 10.318 |
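As in the earlier notebooks of this series, the natural final step would be to load the best checkpoint and report the test-set loss and perplexity (a sketch using the names defined above):
model.load_state_dict(torch.load(MODEL_SAVE_PATH))
test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')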