本文代码分析以哈佛大学NLP实验室实现的版本(https://github.com/harvardnlp/annotated-transformer)进行分析。
1.模型架构图
transformer.PNG2.模型核心流程
2.1 模型整体流程
#模型定义
def make_model(
src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1
):
c = copy.deepcopy
attn = MultiHeadedAttention(h, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
position = PositionalEncoding(d_model, dropout)
model = EncoderDecoder(
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
Generator(d_model, tgt_vocab),
)
...
return model
#模型使用
model.forward(
batch.src, batch.tgt, batch.src_mask, batch.tgt_mask
)
从上面可以看到,模型定义时主要时设置模型参数:第一个参数src_vocab是源语言词汇的个数,tgt_vocab是目标语言词汇的个数。N就是论文中说的Encoder和Decoder的个数,d_model就是embed的维度,d_ff是全连接的维度,h是Header的个数。
而模型在使用时,只需要传入4个参数:
src:待"翻译"语句,已经经过分词处理,并且每个词汇都表示为对应的数字,shape=(32,128),其中32为batch size,下同。
src_mask:因为对源语句进行了补齐(padding),此向量表示对应的padding mask,即对应的补齐位为0,shape = (32, 1, 128)
tgt:"翻译"的结果,已经经过分词处理,并且每个词汇都表示为对应的数字,shape = (32, 127)
tgt_mask:attention mask,是个下三角矩阵,shape = (32, 127, 127)
2.2 模型整体流程
transformer2.png2.3 Embedding
代码中Embedding的定义是这样的,其中用到了两个pytorch模块:nn.Sequential和nn.Embedding。
nn.Sequential(Embeddings(d_model, src_vocab), c(position))
nn.Embedding可以生成一个简单的查找表,将对应的向量用特定的维度表示,如下图:生成一个embedding表示,用维度为3的一维向量唯一的标识0-9的所有数字。
>>> embedding = nn.Embedding(10, 3)
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[
[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]
],
[
[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]
]])
nn.Sequential定义一个有序容器,参数是若干个module,其按照构造函数的顺序依次添加;当Sequential接收参数时,会先传给它的第一个module,其输出作为下一个module的输入,依次链式处理。
所以,整个流程是将src中的每个词汇用维度大小为d_model(512)的一维向量表示,处理后shape=(32, 128,512),然后进行位置编码(position)处理。
2.4 Encoder流程
transformer2-encoder.png1.输入会先经过Norm处理;
2.步骤【1】的输出与src_mask(1.2)同时传入Attention计算注意力(对应论文中的Multi-Head Attention);
3.步骤【2】的输出与步骤【1】的输入做Add运算(对应论文中的Add);
4.步骤【3】的输出再次进行Norm处理(对应论文中的Norm);
5.步骤【4】的输出经Feed Forward处理(对应论文中的Feed Forward);
6.步骤【5】的输出与步骤3的输出做Add运算,其输出作为输入重复步骤【1】(重复6次)。
网友评论