PyTorch中Transformer模型的使用

作者: dalalaa | 来源:发表于2020-02-14 16:04 被阅读0次

    PyTorch最近版本更新很快,1.2/1.3/1.4几乎是连着出,其中:
    1.3/1.4版本主要是新增并完善了PyTorchMobile移动端部署模块和模型量化模块。
    而1.2版中一个重要的更新就是把加入了NLP领域中炙手可热的Transformer模型,这里记录一下PyTorchTransformer模型的用法(代码写于1.2版本,没有在1.3/1.4版本测试)。

    简介


    也许是为了更方便地搭建BertGPT-2之类的NLP模型,PyTorchTransformer相关的模型分为nn.TransformerEncoderLayernn.TransformerDecoderLayernn.LayerNorm等几个部分。搭建模型的时候不一定都会用到, 比如fastai中的Transformer模型就只用到了encoder部分,没有用到decoder

    至于WordEmbeddingPositionEncoding两个部分需要自己另外实现。

    WordEmbedding可以直接使用PyTorch自带的nn.Embedding层。

    PositionEncoding层的花样就多了,不同的模型下面有不同的PositionEncoding,比如Transformer的原始论文Attention is all you need中使用的是无参数的PositionEncodingBert中使用的是带有学习参数的PositionEncoding

    在本文中介绍的是参考Transformer原始论文实现的Sequence2sequence形式的Transformer模型。

    Sequence2sequence形式的Transformer模型搭建:


    无可学习参数的PositionEncoding层

    无参数的PositionEncoding计算速度快,还可以减小整个模型的尺寸,据说在有些任务中,效果与有参数的接近。

    class PositionalEncoding(nn.Module):
        def __init__(self, d_model,dropout=0.1, max_len=5000):
            super(PositionalEncoding, self).__init__()
            self.dropout = nn.Dropout(p=dropout)
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0).transpose(0, 1)
            self.register_buffer('pe', pe)
    
        def forward(self, x):
            x = x + self.pe[:x.size(0), :]
            return self.dropout(x)
    

    有可学习参数的PositionEncoding层

    我曾在一个序列预测任务(非NLP)里面对比过两种PositionEncoding层,发现带有参数的PositionEncoding层效果明显比没有参数的PositionEncoding要好。

    带参数的PositionEncoding层的定义更为简单,直接继承一个nn.Embedding,再续上一个dropout就可以了。因为nn.Embedding中包含了一个可以按索引取向量的权重矩阵weight

    class LearnedPositionEncoding(nn.Embedding):
        def __init__(self,d_model, dropout = 0.1,max_len = 5000):
            super().__init__(max_len, d_model)
            self.dropout = nn.Dropout(p = dropout)
         
        def forward(self, x):
            weight = self.weight.data.unsqueeze(1)
            x = x + weight[:x.size(0),:]
            return self.dropout(x)
    

    Sequence2sequence模型

    class S2sTransformer(nn.Module):
    
        def __init__(self,vocab_size,position_enc,d_model = 512,nhead = 8,num_encoder_layers=6,
                     num_decoder_layers=6,dim_feedforward=2048,dropout=0.1):
            super(S2sTransformer,self).__init__()
    
            # Preprocess
            self.embedding = nn.Embedding(vocab_size,d_model)
            self.pos_encoder_src = position_enc(d_model=512)
            # tgt
            self.pos_encoder_tgt = position_enc(d_model=512)
    
            # Encoder
            encoder_layer = nn.TransformerEncoderLayer(d_model,nhead,dim_feedforward,dropout)
            encoder_norm = nn.LayerNorm(d_model)
            self.encoder = nn.TransformerEncoder(encoder_layer,num_encoder_layers,encoder_norm)
    
            # Decoder
            decoder_layer = nn.TransformerDecoderLayer(d_model,nhead,dim_feedforward,dropout)
            decoder_norm = nn.LayerNorm(d_model)
            self.decoder = nn.TransformerDecoder(decoder_layer,num_decoder_layers,decoder_norm)
            self.output_layer = nn.Linear(d_model,vocab_size)
    
            self._reset_parameters()
            self.d_model = d_model
            self.nhead = nhead
    
    
        def forward(self, src,tgt,src_mask = None,tgt_mask = None,
                    memory_mask = None,src_key_padding_mask = None,
                    tgt_key_padding_mask = None,memory_key_padding_mask = None):
    
            # word embedding
            src = self.embedding(src)
            tgt = self.embedding(tgt)
    
            # shape check
            if src.size(1) != tgt.size(1):
                raise RuntimeError("the batch number of src and tgt must be equal")
            if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
                raise RuntimeError("the feature number of src and tgt must be equal to d_model")
    
            # position encoding
            src = self.pos_encoder_src(src)
            tgt = self.pos_encoder_tgt(tgt)
    
            memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
            output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=memory_key_padding_mask)
            output = self.output_layer(output)
            # return output
            return softmax(output,dim = 2)
    
    
        def generate_square_subsequent_mask(self, sz):
            r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
                Unmasked positions are filled with float(0.0).
            """
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            return mask
    
        def _reset_parameters(self):
            r"""Initiate parameters in the transformer model."""
    
            for p in self.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
    

    模型搭建好了之后,就可以按照Sequence2sequence的训练方式进行训练了, 唯一需要注意的就是Transformerforward过程是并行的,与基于RNNSequence2sequence模型稍有不同。

    训练过程可以参考PyTorch官网提供的chatbot的教程, 如果有问题可以留言、私信或者加群交流:

    机器学习----菜鸡互啄群聊二维码.png

    相关文章

      网友评论

        本文标题:PyTorch中Transformer模型的使用

        本文链接:https://www.haomeiwen.com/subject/jyvaxhtx.html