Fine-tune the mT5 Model


Author: 乘瓠散人 | Published on 2022-04-21 22:31

    As we introduced before, Google's large-scale pre-trained language model T5 (Text-to-Text Transfer Transformer) was trained only on English corpora, so it cannot be applied to Chinese text. Google later released a multilingual version, mT5 (Multilingual T5). Yes, it very much has that "walk your own path and leave everyone else nowhere to go" feel, but it also makes it convenient for us to work with Chinese data (standing on the shoulders of giants).

    This post covers the key code for calling the mT5 model, implemented mainly with the Huggingface transformers library's mT5 classes; the Chinese dataset itself has to be provided by the reader (a minimal dataset sketch is included in section 4 below).

    mT5: A massively multilingual pre-trained text-to-text transformer

    Pytorch-Lightning

    To keep the focus on the core logic of the code, we will use the pytorch-lightning library, so here is a brief introduction first. Pytorch-Lightning is a high-level library built on top of Pytorch that lets users concentrate on the core modeling code and frees them from much of the boilerplate, making experiments lighter and more efficient to run; a minimal skeleton is sketched below. See Lightning in 15 minutes for a quick look at how pytorch-lightning trims down raw pytorch code.

    pytorch-lightning
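
    As a quick illustration of the pattern (the toy model and random data below are placeholders, not part of the mT5 example), a Lightning training loop boils down to a LightningModule plus a Trainer:

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            return loss  # Lightning handles backward, the optimizer step and device placement

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    data = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=2)
    pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False).fit(ToyModule(), data)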

    Using the mT5 Model

    1. Environment Setup

    !pip install --quiet sentencepiece==0.1.96  # note: install this before transformers
    !pip install --quiet transformers==4.18.0
    !pip install --quiet pytorch-lightning==1.6.1
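
    A quick optional sanity check that the pinned versions were picked up (the print statements are an addition, not from the original post):

    import sentencepiece, transformers, pytorch_lightning
    print(sentencepiece.__version__)      # expected: 0.1.96
    print(transformers.__version__)       # expected: 4.18.0
    print(pytorch_lightning.__version__)  # expected: 1.6.1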
    

    2. Load the Pretrained Model

    from transformers import MT5Tokenizer, MT5ForConditionalGeneration
    import torch
    
    tokenizer = MT5Tokenizer.from_pretrained('google/mt5-small')
    model = MT5ForConditionalGeneration.from_pretrained('google/mt5-small')
    
    # the following 2 hyperparameters are task-specific
    max_source_length = 128 # 512
    max_target_length = 128
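
    Optionally, we can check the size of the loaded checkpoint; mt5-small has roughly 300M parameters, a large share of which is the 250k-token multilingual vocabulary embedding (the snippet below is a sketch added here, not from the original post):

    # count the parameters of the loaded mt5-small model
    num_params = sum(p.numel() for p in model.parameters())
    print(f"mt5-small parameters: {num_params / 1e6:.1f}M")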
    

    3. Tokenize Sentences

    Here we print a few examples to make it easier to see how the input sentences are tokenized.

    # Suppose we have the following 2 training examples:
    input_sequence_1 = "Welcome to Beijing"
    output_sequence_1 = "欢迎来到北京"
    
    input_sequence_2 = "HuggingFace is a company"
    output_sequence_2 = "拥抱脸是一家公司"
    
    # encode the inputs
    task_prefix = "translate English to Chinese: "
    input_sequences = [input_sequence_1, input_sequence_2]
    
    input_tokens_1 = tokenizer.tokenize(task_prefix + input_sequence_1)
    print('input_tokens_1:', input_tokens_1)
    output_tokens_1 = tokenizer.tokenize(output_sequence_1)
    print('output_tokens_1:', output_tokens_1)
    
    input_tokens_2 = tokenizer.tokenize(task_prefix + input_sequence_2)
    print('input_tokens_2:', input_tokens_2)
    output_tokens_2 = tokenizer.tokenize(output_sequence_2)
    print('output_tokens_2:', output_tokens_2)
    
    encoding = tokenizer(
        [task_prefix + sequence for sequence in input_sequences],
        padding="longest", # pad to the longest sequence in the batch
        max_length=max_source_length,
        truncation=True,
        return_tensors="pt",
    )
    
    print('encoding', encoding)
    input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
    
    # encode the targets
    target_encoding = tokenizer(
        [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
    )
    labels = target_encoding.input_ids
    
    print('Labels:', labels)
    
    

    The resulting output is:

    [figure: tokenizer output]

    Notice that each sentence's input_ids end with an extra token with id 1; this corresponds to the end-of-sequence token </s> that is appended to every sentence.
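
    One more detail that matters when these labels are used for fine-tuning (this follows the common practice shown in the Transformers documentation rather than code from the original post): padding token ids in the labels are usually replaced with -100 so that the loss ignores them.

    # re-encode the targets as tensors and mask out padding in the labels
    target_encoding = tokenizer(
        [output_sequence_1, output_sequence_2],
        padding="longest",
        max_length=max_target_length,
        truncation=True,
        return_tensors="pt",
    )
    labels = target_encoding.input_ids
    labels[labels == tokenizer.pad_token_id] = -100  # -100 is ignored by the cross-entropy loss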

    4. Prepare the Model

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from transformers import AdamW
    
    class MT5FineTuner(pl.LightningModule):
        def __init__(self, hparams, mt5model, mt5tokenizer):
            super(MT5FineTuner, self).__init__()
            # self.hparams = hparams  # direct assignment is no longer supported; use save_hyperparameters instead
            self.save_hyperparameters(hparams)
            self.model = mt5model
            self.tokenizer = mt5tokenizer
    
        def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None,
                    lm_labels=None):
            # decoder_input_ids is deliberately not forwarded: when labels are provided,
            # the model builds decoder inputs internally by shifting the labels to the right.
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
                labels=lm_labels,
            )
    
            return outputs
    
        def training_step(self, batch, batch_idx):
            outputs = self.forward(
                input_ids=batch["source_ids"],
                attention_mask=batch["source_mask"],
                decoder_input_ids=batch["target_ids"],
                decoder_attention_mask=batch['target_mask'],
                lm_labels=batch['labels']
            )
    
            loss = outputs[0]
            self.log('train_loss', loss)
            return loss
    
        def validation_step(self, batch, batch_idx):
            outputs = self.forward(
                input_ids=batch["source_ids"],
                attention_mask=batch["source_mask"],
                decoder_input_ids=batch["target_ids"],
                decoder_attention_mask=batch['target_mask'],
                lm_labels=batch['labels']
            )
    
            loss = outputs[0]
            self.log("val_loss", loss)
            return loss
    
        def train_dataloader(self):
            # train_dataset must be provided by the user (see the dataset sketch below)
            return DataLoader(train_dataset, batch_size=self.hparams.batch_size, num_workers=4)
    
        def val_dataloader(self):
            # validation_dataset must be provided by the user as well
            return DataLoader(validation_dataset, batch_size=self.hparams.batch_size, num_workers=4)
    
        def configure_optimizers(self):
            optimizer = AdamW(self.parameters(), lr=3e-4, eps=1e-8)
            return optimizer
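
    The dataloaders above expect train_dataset and validation_dataset to exist and to yield the batch keys used in training_step. The original post leaves the Chinese dataset to the reader; the class below is only a minimal sketch of one possible layout (the class name and the in-memory sentence pairs are placeholders).

    from torch.utils.data import Dataset
    
    class TranslationDataset(Dataset):  # hypothetical helper, not from the original post
        def __init__(self, pairs, tokenizer, max_source_length=128, max_target_length=128):
            self.pairs = pairs  # list of (English sentence, Chinese sentence) tuples
            self.tokenizer = tokenizer
            self.max_source_length = max_source_length
            self.max_target_length = max_target_length
            self.task_prefix = "translate English to Chinese: "
    
        def __len__(self):
            return len(self.pairs)
    
        def __getitem__(self, idx):
            source, target = self.pairs[idx]
            source_enc = self.tokenizer(
                self.task_prefix + source,
                max_length=self.max_source_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            target_enc = self.tokenizer(
                target,
                max_length=self.max_target_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            labels = target_enc.input_ids.squeeze(0).clone()
            labels[labels == self.tokenizer.pad_token_id] = -100  # ignored by the loss
            return {
                "source_ids": source_enc.input_ids.squeeze(0),
                "source_mask": source_enc.attention_mask.squeeze(0),
                "target_ids": target_enc.input_ids.squeeze(0),
                "target_mask": target_enc.attention_mask.squeeze(0),
                "labels": labels,
            }
    
    train_dataset = TranslationDataset([("Welcome to Beijing", "欢迎来到北京")], tokenizer)
    validation_dataset = TranslationDataset([("HuggingFace is a company", "拥抱脸是一家公司")], tokenizer)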
    
    

    5. Train the Model

    import argparse
    
    args_dict = dict(
        batch_size=1,
    )
    args = argparse.Namespace(**args_dict)
    
    # mt5_model and mt5_tokenizer are the pretrained model and tokenizer loaded in step 2
    # (they were created there under the names `model` and `tokenizer`)
    mt5_model, mt5_tokenizer = model, tokenizer
    
    model = MT5FineTuner(args, mt5_model, mt5_tokenizer)
    
    trainer = pl.Trainer(max_epochs=5, gpus=1, log_every_n_steps=1)
    trainer.fit(model)
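
    After training, the fine-tuned weights and the tokenizer can be saved in the usual Huggingface format (the output path below is an arbitrary choice, not from the original post):

    model.model.save_pretrained("./mt5-finetuned")
    mt5_tokenizer.save_pretrained("./mt5-finetuned")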
    
    

    6. Test the Model

    # keep the prompt consistent with the task prefix used during fine-tuning
    test_sent = 'translate English to Chinese: The sailor was happy.'
    test_tokenized = mt5_tokenizer(test_sent, return_tensors="pt")
    
    test_input_ids = test_tokenized["input_ids"]
    test_attention_mask = test_tokenized["attention_mask"]
    
    model.model.eval()
    beam_outputs = model.model.generate(
        input_ids=test_input_ids, 
        attention_mask=test_attention_mask
    )
    sent = mt5_tokenizer.decode(beam_outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print(sent)
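
    With no extra arguments, generate typically falls back to greedy decoding with a short default max_length, so the translation may be cut off. The variant below is an illustrative sketch (the decoding values are assumptions, not from the original post) using beam search and a longer output limit:

    beam_outputs = model.model.generate(
        input_ids=test_input_ids,
        attention_mask=test_attention_mask,
        max_length=max_target_length,  # defined in step 2
        num_beams=4,
        early_stopping=True,
    )
    print(mt5_tokenizer.decode(beam_outputs[0], skip_special_tokens=True))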
    
    

    For the complete code, see Minimalistic training of T5 transformer with Pytorch Lightning and HuggingFace.ipynb

    Pytorch Lightning 完全攻略 (A Complete Guide to Pytorch Lightning), Zhihu
    Huggingface mT5 tutorial
