美文网首页
【NLP】BERT将预训练tensorflow模型转换为pyto

【NLP】BERT将预训练tensorflow模型转换为pyto

作者: DeepNLPLearner | 来源:发表于2020-11-13 10:48 被阅读0次

在Bert的预训练模型中,主流的模型都是以tensorflow的形势开源的。好在Transformers提供了一份可以转换的接口。

官方演示如下:

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12transformers-cli 
convert --model_type bert\
             --tf_checkpoint$BERT_BASE_DIR/bert_model.ckpt\
             --config$BERT_BASE_DIR/bert_config.json\
             --pytorch_dump_output$BERT_BASE_DIR/pytorch_model.bin

但是如何在windows的IDE中执行呢?
首先,需要安装transformers (可以挂国内清华、豆瓣源之类的加速)

pip install transformers

然后:

import transformers.convert_bert_original_tf_checkpoint_to_pytorch as con
con.convert_tf_checkpoint_to_pytorch(
    r'.\chinese_bert_chinese_wwm_L-12_H-768_A-12\publish\bert_model.ckpt',
    r'.\chinese_bert_chinese_wwm_L-12_H-768_A-12\publish\bert_config.json',
    r'.\chinese_bert_chinese_wwm_L-12_H-768_A-12\publish\pytorch_bert.bin'
)

convert_tf_checkpoint_to_pytorch中三个参数分别是:bert模型名称、config文件地址,输出的pytorch文件地址

通常的Bert预训练文件包含以下内容


image.png

其中 bert_model.chpt 本身是不存在的,我们传入的 bert_model.chpt 只要按照文件夹中模型的名称给出即可,不需要加 .index 后缀

正确执行后,你会看到:


image.png

对于 tensorflow2 的Bert模型,还有对应的转换接口:

import transformers.convert_bert_original_tf2_checkpoint_to_pytorch as con
con.convert_tf2_checkpoint_to_pytorch(……………………)

附上官方文档地址:https://huggingface.co/transformers/converting_tensorflow_models.html

相关文章

网友评论

      本文标题:【NLP】BERT将预训练tensorflow模型转换为pyto

      本文链接:https://www.haomeiwen.com/subject/jfjvbktx.html