今天学习一个新的自然语言处理任务——文本纠错。文本纠错这个领域其实有细分成很多不同的类型:如下图所示
image.png
其中不同的问题需要采取不同的策略进行解决。传统的文本纠错一般会分为两个步骤
:错误检测和错误纠正。但是随着深度学习的发展,Seq2Seq的模型可以一步到位,端到端的解决文本纠错的问题。接下来笔者就来介绍一下,采用Bart模型进行文本纠错的实战过程。
BART模型简介
BART全称是: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension,
论文主要提出了一种 对抗噪声的方式来做语言模型的预训练,并将用此种方式预训练好的语言模型用于自然语言生成,翻译和理解。
如下图所示 BERT 使用了transformer的encode, GPT使用了transformer的decode,BART模型的整个架构其实是标准的transformer的架构,等于BERT架构+ GPT架构 + 对抗噪声预训练。
其实 采用完整transformer 结构的预训练语言模型不只有BART,还有google的T5以及微软的MASS,它们也都在自然语言生成式的下游任务中表现不错。而这三个模型之间最大的区别可能就是预训练任务的设计。接下来我们来看看BART的预训练任务是如何设计的。如下图所示,作者设计了预训练任务就是给语言模型输入各种各样的扰乱,然后让语言模型恢复原句,具体的扰动设计有如图所示的五个方式:
- 字符掩码
- 句子中的字符进行打乱
- 将句子进行倒序
- 字符删除
- 词级别的掩码
在这种对抗噪声的预训练之后,下游任务在此基础上finetune 就能取得非常好的效果。
这里笔者在huggface上下载了一个small 版本的中文BART.其下载详情页如下所示:
输入:中国的首都是[MASK]京 输出:中国的首都是北京
输入:作为电子[MASK]平台 输出:作为电子商务平台
模型下载地址:https://huggingface.co/uer/bart-chinese-6-960-cluecorpussmall/tree/main
数据介绍
数据集为SIGHAN+Wang271K中文纠错数据集,下载地址在https://github.com/shibing624/pycorrector/tree/master/pycorrector/t5这个页面之中。
其具体格式如下图所示:
[ {
"id":"-",
"original_text":"目前区次事件的细节还不清楚,伤亡人数也未确定。",
"wrong_ids":[
2],
"correct_text":"目前这次事件的细节还不清楚,伤亡人数也未确定。"},
{
"id":"-",
"original_text":"报导中并未说明出口量,但据引述药厂主管的话指出,每一种药物最大的庄度出口量都达到十二吨之谱。",
"wrong_ids":[
32
],
"correct_text":"报导中并未说明出口量,但据引述药厂主管的话指出,每一种药物最大的年度出口量都达到十二吨之谱。" }]
最终任务就是将错误的句子输入给模型,输出正确的句子。
image.png数据预处理
这一步实现数据的加载,数据的预处理,同时实现了一下模型预测函数。
import json
from dataclasses import dataclass, field
from typing import Optional
import os
import argparse
from transformers import AutoTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline
from transformers import HfArgumentParser, TrainingArguments, Trainer, set_seed
from datasets import load_dataset, Dataset
from loguru import logger
class CscDataset(object):
def __init__(self, file_path):
self.data = json.load(open(file_path, 'r', encoding='utf-8'))
def load(self):
data_list = []
for item in self.data:
data_list.append(item['original_text'] + '\t' + item['correct_text'])
if len(data_list)>10000:
break
return {'text': data_list}
import torch
def bart_correct(tokenizer, model, text: str, max_length: int = 128):
import numpy as np
inputs = tokenizer.encode(text, padding=True, max_length=32, truncation=True,
return_tensors='pt')
model.eval()
with torch.no_grad():
res = model(inputs).logits
res = np.argmax(res[0],axis=1)
res = res[1:-1]
decode_tokens = tokenizer.decode(res,skip_special_tokens=True).replace(' ', '')
return decode_tokens
d = CscDataset("./csc_sample/train.json")
data_dict = d.load()
train_dataset = Dataset.from_dict(data_dict, split='train')
d = CscDataset("./csc_sample/test.json")
data_dict = d.load()
valid_dataset = Dataset.from_dict(data_dict, split='test')
logger.info(train_dataset)
logger.info(valid_dataset)
def tokenize_dataset(tokenizer, dataset, max_len):
def convert_to_features(example_batch):
src_texts = []
trg_texts = []
for example in example_batch['text']:
terms = example.split('\t', 1)
src_texts.append(terms[0])
trg_texts.append(terms[1])
input_encodings = tokenizer.batch_encode_plus(
src_texts,
truncation=True,
padding='max_length',
max_length=max_len,
)
target_encodings = tokenizer.batch_encode_plus(
trg_texts,
truncation=True,
padding='max_length',
max_length=max_len,
)
encodings = {
'input_ids': input_encodings['input_ids'],
'attention_mask': input_encodings['attention_mask'],
'target_ids': target_encodings['input_ids'],
'target_attention_mask': target_encodings['attention_mask']
}
return encodings
dataset = dataset.map(convert_to_features, batched=True)
# Set the tensor type and the columns which the dataset should return
columns = ['input_ids', 'target_ids', 'attention_mask', 'target_attention_mask']
dataset.with_format(type='torch', columns=columns)
# Rename columns to the names that the forward method of the selected
# model expects
dataset = dataset.rename_column('target_ids', 'labels')
dataset = dataset.rename_column('target_attention_mask', 'decoder_attention_mask')
dataset = dataset.remove_columns(['text'])
return dataset
train_data = tokenize_dataset(tokenizer, train_dataset,128)
valid_data = tokenize_dataset(tokenizer, valid_dataset,128)
bart_correct(tokenizer, model,"中国的首都是[MASK]京",32)
模型加载
这里加载从huggface 上下载的BART模型,这里我将文件夹命名成了bart.
tokenizer = AutoTokenizer.from_pretrained("./bart/")
model = BartForConditionalGeneration.from_pretrained("./bart/")
测试了一下模型的纠错能力,发现对【MASK】这个字符效果不错,这也是得益于预训练任务有字符掩码复原这个任务。但是错字纠正却做得不好。接下来我们来再文本纠错的数据集上进行finetune一下。看看效果如何。
image.png
模型训练finetune
设置好训练参数,进行模型finetune,这里由于笔者的硬件比较简陋(CPU笔记本),只在10000个样本上训练了2轮,耗时2个小时。
training_args = TrainingArguments(
output_dir='./results', # output directory 结果输出地址
num_train_epochs=2, # total # of training epochs 训练总批次
per_device_train_batch_size=32, # batch size per device during training 训练批大小
per_device_eval_batch_size=32, # batch size for evaluation 评估批大小
logging_dir='./logs/rn_log', # directory for storing logs 日志存储位置
learning_rate=1e-4, # 学习率
save_steps=False,# 不保存检查点
logging_steps=2
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_data,
eval_dataset=valid_data,
)
trainer.train()
##模型保存
model.save_pretrained("result_bart/")
这里可以看到loss正在下降,18个step后loss就从6降到了0.15。
image.png
模型预测
接下来我们采用训练后的模型再试试纠错效果。
new_model = BartForConditionalGeneration.from_pretrained("./result_bart/")
测试了4个错误句子,只训练了2轮,就具备了一定的纠错能力。
image.png
结语
在这种句子生成任务预训练的模型上进行文本对纠错微调,确实有非常好的效果,这样得益于预训练任务和下游任务非常相似。所以,以后面对文本生成的任务,都可以采用对BART,T5等进行微调,也行就能取得不错的效果。
参考
https://github.com/shibing624/pycorrector/tree/develop/pycorrector
https://blog.csdn.net/kittyzc/article/details/124926125
https://huggingface.co/uer/bart-chinese-6-960-cluecorpussmall
网友评论