美文网首页
加载训练好的BERT参数

加载训练好的BERT参数

作者: IT_小马哥 | 来源:发表于2020-11-29 15:27 被阅读0次

    将预训练模型中的bert部分取出来加载上去

    base_model = BaseModel(config)
    base_model_dict = base_model.state_dict()

    加载训练好的模型

    pre_state_dict = torch.load(Bert_path)

    new_state_dict = {k: v for k, v in pre_state_dict.items() if k in base_model_dict}
    base_model_dict.update(new_state_dict)
    base_model.load_state_dict(base_model_dict)

    class BaseModel(nn.Module):
        def __init__(self, config):
            super(BaseModel, self).__init__()
            self.bert_model = BertModel.from_pretrained(config.bert_uncased_path, output_hidden_states=True,
                                                        output_attentions=True)
            for p in self.bert_model.parameters():
                p.requires_grad = False  # 预训练模型加载进来后全部设置为不更新参数,然后再后面加层
    
        def forward(self, input_ids, attention_mask):
            last_hidden_state = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[1]
            return last_hidden_state
    
    

    新的方法

    • 可以把训练好的模型参数加载给新的模型。新的模型和老模型有部分相同。
                        model_orig = Bert_Classify(class_num)
                        model_orig.load_state_dict(torch.load(classifier_name))
    
                        model = NewModel(one_t, class_num)
                        model.load_state_dict(model_orig.state_dict(), strict=False)
    
    

    相关文章

      网友评论

          本文标题:加载训练好的BERT参数

          本文链接:https://www.haomeiwen.com/subject/gqdkwktx.html