Understanding UIE Entity Extraction

Author: lodestar | Published 2022-08-25 22:46

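    UIE (Universal Information Extraction): Yaojie Lu et al. proposed UIE, a unified framework for universal information extraction, at ACL 2022. The framework models entity extraction, relation extraction, event extraction, sentiment analysis and similar tasks in a single unified form, which lets the different tasks transfer to and generalize across one another well.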

    Code: https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/uie

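    doccano is an open-source text annotation tool. Annotations exported from it look like this (the labels 时间, 目的地 and 费用 mean time, destination and cost; offsets are character positions with an exclusive end):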

    #doccano_ext.json
    {"id": 1, "text": "昨天晚上十点加班打车回家58元", "relations": [], "entities": [{"id": 0, "start_offset": 0, "end_offset": 6, "label": "时间"}, {"id": 1, "start_offset": 11, "end_offset": 12, "label": "目的地"}, {"id": 2, "start_offset": 12, "end_offset": 14, "label": "费用"}]}
    {"id": 2, "text": "三月三号早上12点46加班,到公司54", "relations": [], "entities": [{"id": 3, "start_offset": 0, "end_offset": 11, "label": "时间"}, {"id": 4, "start_offset": 15, "end_offset": 17, "label": "目的地"}, {"id": 5, "start_offset": 17, "end_offset": 19, "label": "费用"}]}
    {"id": 3, "text": "8月31号十一点零四工作加班五十块钱", "relations": [], "entities": [{"id": 6, "start_offset": 0, "end_offset": 10, "label": "时间"}, {"id": 7, "start_offset": 14, "end_offset": 16, "label": "费用"}]}
    {"id": 4, "text": "5月17号晚上10点35分加班打车回家,36块五", "relations": [], "entities": [{"id": 8, "start_offset": 0, "end_offset": 13, "label": "时间"}, {"id": 1, "start_offset": 18, "end_offset": 19, "label": "目的地"}, {"id": 9, "start_offset": 20, "end_offset": 24, "label": "费用"}]}
    {"id": 5, "text": "2009年1月份通讯费一百元", "relations": [], "entities": [{"id": 10, "start_offset": 0, "end_offset": 7, "label": "时间"}, {"id": 11, "start_offset": 11, "end_offset": 13, "label": "费用"}]}
    
    python doccano.py \
        --doccano_file ./data/doccano_ext.json \
        --task_type ext \
        --save_dir ./data \
        --negative_ratio 0 \
        --splits 0.8 0.2 0
    

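    doccano.py converts the annotations into prompt-style examples, using the entity label as the prompt. The resulting train.txt has the following format: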

    {"content": "6月2日交通费123元", "result_list": [{"text": "6月2日", "start": 0, "end": 4}], "prompt": "时间"}
    {"content": "上海虹桥高铁到杭州时间是9月24日费用是73元", "result_list": [{"text": "上海虹桥", "start": 0, "end": 4}], "prompt": "出发地"}
    {"content": "从北京飞往上海出差飞机票费150元", "result_list": [{"text": "上海", "start": 5, "end": 7}], "prompt": "目的地"}
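
    To make the relationship between the two formats concrete, here is a minimal sketch of the conversion for entity extraction. It is not the code in doccano.py: the function name doccano_to_prompt_examples is hypothetical, and negative sampling and dataset splitting are left out. It only assumes what the samples above show, namely one output example per label, with the label as the prompt.

```python
import json
from collections import defaultdict


def doccano_to_prompt_examples(line):
    """Turn one doccano JSON line into one prompt-style example per label.

    Hypothetical helper for illustration; the real doccano.py does more
    (negative examples, train/dev/test splits, relation handling).
    """
    record = json.loads(line)
    text = record["text"]
    spans_by_label = defaultdict(list)
    for entity in record["entities"]:
        start, end = entity["start_offset"], entity["end_offset"]
        # Offsets are character-based and the end is exclusive,
        # so text[start:end] is exactly the annotated span.
        spans_by_label[entity["label"]].append(
            {"text": text[start:end], "start": start, "end": end})
    return [{"content": text, "result_list": spans, "prompt": label}
            for label, spans in spans_by_label.items()]


line = ('{"id": 5, "text": "2009年1月份通讯费一百元", "relations": [], '
        '"entities": [{"id": 10, "start_offset": 0, "end_offset": 7, "label": "时间"}, '
        '{"id": 11, "start_offset": 11, "end_offset": 13, "label": "费用"}]}')
examples = doccano_to_prompt_examples(line)
for example in examples:
    print(json.dumps(example, ensure_ascii=False))
```

    Each printed line has the same three keys as the train.txt samples: content, result_list and prompt.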
    

    Converting the train.txt examples into input_ids and start/end label vectors:

    import numpy as np

    # map_offset is not shown in this excerpt; it maps a character offset
    # to the index of the token that covers it.

    def convert_example(example, tokenizer, max_seq_len):
        # Prompt learning: the prompt (text) and the content (text_pair)
        # are encoded together as one sequence pair.
        encoded_inputs = tokenizer(text=[example["prompt"]],
                                   text_pair=[example["content"]],
                                   truncation=True,
                                   max_seq_len=max_seq_len,
                                   pad_to_max_seq_len=True,
                                   return_attention_mask=True,
                                   return_position_ids=True,
                                   return_dict=False,
                                   return_offsets_mapping=True)
        encoded_inputs = encoded_inputs[0]
        # offset_mapping links each token to its character span in the original
        # text, i.e. positions before and after tokenization.
        offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"]]
        # Prompt and content spans both start at 0, so shift the content spans
        # by bias to keep them from overlapping the prompt spans.
        bias = 0
        for index in range(1, len(offset_mapping)):
            mapping = offset_mapping[index]
            if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
                bias = offset_mapping[index - 1][1] + 1  # Includes [SEP] token
            if mapping[0] == 0 and mapping[1] == 0:
                continue  # special token or padding
            offset_mapping[index][0] += bias
            offset_mapping[index][1] += bias
        start_ids = [0 for x in range(max_seq_len)]
        end_ids = [0 for x in range(max_seq_len)]
        for item in example["result_list"]:
            # item["end"] is exclusive, hence the -1 to reach the last character.
            start = map_offset(item["start"] + bias, offset_mapping)
            end = map_offset(item["end"] - 1 + bias, offset_mapping)
            # start and end are token positions in input_ids
            start_ids[start] = 1.0
            end_ids[end] = 1.0

        tokenized_output = [
            encoded_inputs["input_ids"], encoded_inputs["token_type_ids"],
            encoded_inputs["position_ids"], encoded_inputs["attention_mask"],
            start_ids, end_ids
        ]
        tokenized_output = [np.array(x, dtype="int64") for x in tokenized_output]
        return tuple(tokenized_output)
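
    The offset arithmetic in convert_example is easier to follow on a toy input. The sketch below repeats the bias logic without a tokenizer. The offset_mapping is written by hand for the prompt 时间 and the content 6月2日交通费123元 and is an assumption about how a tokenizer might split them, not real tokenizer output. The map_offset shown here is likewise an assumed implementation that matches how convert_example calls it, and shift_content_offsets is a hypothetical name.

```python
def shift_content_offsets(offset_mapping):
    """Shift content-token spans past the prompt, as convert_example does."""
    shifted = [list(span) for span in offset_mapping]
    bias = 0
    for index in range(1, len(shifted)):
        start, end = shifted[index]
        if start == 0 and end == 0:
            # (0, 0) marks a special token or padding. The first one after
            # the prompt is [SEP]; the +1 accounts for it.
            if bias == 0:
                bias = shifted[index - 1][1] + 1
            continue
        shifted[index][0] += bias
        shifted[index][1] += bias
    return shifted, bias


def map_offset(ori_offset, offset_mapping):
    """Return the index of the token whose span covers ori_offset, or -1."""
    for index, (start, end) in enumerate(offset_mapping):
        if start <= ori_offset < end:
            return index
    return -1


# Hand-written spans: [CLS] 时 间 [SEP] 6 月 2 日 交 通 费 123 元 [SEP] + padding
offset_mapping = [(0, 0), (0, 1), (1, 2), (0, 0),
                  (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7),
                  (7, 10), (10, 11), (0, 0), (0, 0), (0, 0)]
shifted, bias = shift_content_offsets(offset_mapping)

# Entity "6月2日" has character offsets start=0, end=4 (exclusive) in the content.
start_token = map_offset(0 + bias, shifted)
end_token = map_offset(4 - 1 + bias, shifted)
print(bias, start_token, end_token)
```

    With these spans the bias is 3 (two prompt characters plus [SEP]), and the entity lands on token positions 4 and 7, which are the positions set to 1 in start_ids and end_ids.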
    

    Model:

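    import paddle
    import paddle.nn as nn
    from paddlenlp.transformers import ErniePretrainedModel

    class UIE(ErniePretrainedModel):

        def __init__(self, encoding_model):
            super(UIE, self).__init__()
            self.encoder = encoding_model
            hidden_size = self.encoder.config["hidden_size"]
            # Two per-token binary classifiers: one scores span starts,
            # the other scores span ends.
            self.linear_start = paddle.nn.Linear(hidden_size, 1)
            self.linear_end = paddle.nn.Linear(hidden_size, 1)
            self.sigmoid = nn.Sigmoid()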
    
        def forward(self, input_ids, token_type_ids, pos_ids, att_mask):
            sequence_output, pooled_output = self.encoder(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                position_ids=pos_ids,
                attention_mask=att_mask)
            start_logits = self.linear_start(sequence_output)
            start_logits = paddle.squeeze(start_logits, -1)
            start_prob = self.sigmoid(start_logits)
            end_logits = self.linear_end(sequence_output)
            end_logits = paddle.squeeze(end_logits, -1)
            end_prob = self.sigmoid(end_logits)
            return start_prob, end_prob
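
    The shapes in forward can be checked without Paddle. The following NumPy sketch stands in for the two linear heads with random weights, purely to show how a [batch_size, seq_len, hidden_size] encoder output becomes two [batch_size, seq_len] probability matrices. The sizes and the random values are arbitrary assumptions, and no trained model is involved.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


rng = np.random.default_rng(0)
batch_size, seq_len, hidden_size = 2, 8, 16

# Stand-in for the encoder's sequence_output.
sequence_output = rng.normal(size=(batch_size, seq_len, hidden_size))

# Stand-ins for linear_start and linear_end: hidden_size -> 1.
w_start, b_start = rng.normal(size=(hidden_size, 1)), np.zeros(1)
w_end, b_end = rng.normal(size=(hidden_size, 1)), np.zeros(1)

start_logits = sequence_output @ w_start + b_start      # (2, 8, 1)
start_prob = sigmoid(np.squeeze(start_logits, axis=-1))  # (2, 8)
end_logits = sequence_output @ w_end + b_end            # (2, 8, 1)
end_prob = sigmoid(np.squeeze(end_logits, axis=-1))      # (2, 8)

print(start_logits.shape, start_prob.shape, end_prob.shape)
```

    Each entry of start_prob and end_prob is an independent probability that the token at that position begins or ends a span for the given prompt.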
    

    Training code excerpt:

        model = UIE.from_pretrained(args.model)
        criterion = paddle.nn.BCELoss()
        for epoch in range(1, args.num_epochs + 1):
            for batch in train_data_loader:
                # Field order follows the tuple returned by convert_example.
                input_ids, token_type_ids, pos_ids, att_mask, start_ids, end_ids = batch
                start_prob, end_prob = model(input_ids, token_type_ids, pos_ids, att_mask)
                # start_ids has shape [batch_size, max_seq_len];
                # it marks span start positions (end_ids marks span ends).
                start_ids = paddle.cast(start_ids, 'float32')
                end_ids = paddle.cast(end_ids, 'float32')
                loss_start = criterion(start_prob, start_ids)
                loss_end = criterion(end_prob, end_ids)
                # Average the two binary cross-entropy losses.
                loss = (loss_start + loss_end) / 2.0
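
    As a worked example of this loss, the NumPy sketch below computes binary cross-entropy by hand and averages the start and end terms. It assumes mean reduction over all positions, which is the default for paddle.nn.BCELoss. The probabilities and labels are made up for illustration, and the eps clipping is an added safeguard against log(0).

```python
import numpy as np


def bce_loss(prob, label, eps=1e-12):
    """Mean binary cross-entropy over all positions."""
    prob = np.clip(prob, eps, 1.0 - eps)
    per_token = -(label * np.log(prob) + (1.0 - label) * np.log(1.0 - prob))
    return float(np.mean(per_token))


# One sequence of four tokens; the span starts at token 0 and ends at token 2.
start_prob = np.array([[0.9, 0.1, 0.1, 0.1]])
start_ids = np.array([[1.0, 0.0, 0.0, 0.0]])
end_prob = np.array([[0.1, 0.1, 0.9, 0.1]])
end_ids = np.array([[0.0, 0.0, 1.0, 0.0]])

loss_start = bce_loss(start_prob, start_ids)
loss_end = bce_loss(end_prob, end_ids)
loss = (loss_start + loss_end) / 2.0
print(round(loss, 4))
```

    Every position here contributes -log(0.9), so the combined loss is about 0.1054. A confident wrong prediction would contribute far more, for example -log(0.1), about 2.3026.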
    

    Takeaways

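    For beginners: first understand the whole pipeline (training data processing, choice of loss function, choice of model, evaluation function), then get the code running end to end, and finally step through it with breakpoints in a debugger to see how the tensor shapes change at each key stage.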
