美文网首页
加载训练好的BERT参数

加载训练好的BERT参数

作者: IT_小马哥 | 来源:发表于2020-11-29 15:27 被阅读0次

将预训练模型中的bert部分取出来加载上去

base_model = BaseModel(config)
base_model_dict = base_model.state_dict()

加载训练好的模型

pre_state_dict = torch.load(Bert_path)

new_state_dict = {k: v for k, v in pre_state_dict.items() if k in base_model_dict}
base_model_dict.update(new_state_dict)
base_model.load_state_dict(base_model_dict)

class BaseModel(nn.Module):
    def __init__(self, config):
        super(BaseModel, self).__init__()
        self.bert_model = BertModel.from_pretrained(config.bert_uncased_path, output_hidden_states=True,
                                                    output_attentions=True)
        for p in self.bert_model.parameters():
            p.requires_grad = False  # 预训练模型加载进来后全部设置为不更新参数,然后再后面加层

    def forward(self, input_ids, attention_mask):
        last_hidden_state = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[1]
        return last_hidden_state

新的方法

  • 可以把训练好的模型参数加载给新的模型。新的模型和老模型有部分相同。
                    model_orig = Bert_Classify(class_num)
                    model_orig.load_state_dict(torch.load(classifier_name))

                    model = NewModel(one_t, class_num)
                    model.load_state_dict(model_orig.state_dict(), strict=False)

相关文章

  • 加载训练好的BERT参数

    将预训练模型中的bert部分取出来加载上去 base_model = BaseModel(config)base_...

  • 神经网络识别手写优化(三)

    前言 本文是为了实现存储自己训练好的模型 结构和参数,以及加载训练好的模型进行预测。 代码 保存 加载

  • pytorch:Transformers入门(四)

    前面学习了Bert相关的类,每个类在实例化时,使用from_pretrained函数加载与训练的模型参数来初始化,...

  • PyTorch如何恢复指定权重

    1. 如何从已训练好的网络模型中提取指定层权重 2. 如何加载模型部分参数并更新 可以发现classifier.2...

  • 使用 Rasa Forms 构建上下文助手

    支持加载的权重: Google原版bert: https://github.com/google-research...

  • RoBERTa

    相比较bert,RoBERTa有以下几个改进: 模型参数:RoBERTa采用更大模型参数(1024 块 V100 ...

  • TensorFlow 调用预训练好的模型—— Python 实现

    1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值...

  • 封装图片预加载器

    1、图片预加载器插件的参数 参数参数表示的意义data图片地址each监听图片加载过程success所有图片加载完...

  • python训练好的模型保存与加载

    将训练好的模型保存下来,避免重复训练:以下使用pickle和joblib保存训练好的模型以及模型的加载 模型训练,...

  • Distilbert

    因为Bert本身参数量大,所以上线的过程中会碰到需求大空间和速度慢等问题。当前对Bert瘦身有三个思路,分别是Di...

网友评论

      本文标题:加载训练好的BERT参数

      本文链接:https://www.haomeiwen.com/subject/gqdkwktx.html