在Bert的预训练模型中,主流的模型都是以tensorflow的形势开源的。好在Transformers提供了一份可以转换的接口。
官方演示如下:
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12transformers-cli
convert --model_type bert\
--tf_checkpoint$BERT_BASE_DIR/bert_model.ckpt\
--config$BERT_BASE_DIR/bert_config.json\
--pytorch_dump_output$BERT_BASE_DIR/pytorch_model.bin
但是如何在windows的IDE中执行呢?
首先,需要安装transformers (可以挂国内清华、豆瓣源之类的加速)
pip install transformers
然后:
import transformers.convert_bert_original_tf_checkpoint_to_pytorch as con
con.convert_tf_checkpoint_to_pytorch(
r'.\chinese_bert_chinese_wwm_L-12_H-768_A-12\publish\bert_model.ckpt',
r'.\chinese_bert_chinese_wwm_L-12_H-768_A-12\publish\bert_config.json',
r'.\chinese_bert_chinese_wwm_L-12_H-768_A-12\publish\pytorch_bert.bin'
)
convert_tf_checkpoint_to_pytorch中三个参数分别是:bert模型名称、config文件地址,输出的pytorch文件地址
通常的Bert预训练文件包含以下内容
image.png
其中 bert_model.chpt 本身是不存在的,我们传入的 bert_model.chpt 只要按照文件夹中模型的名称给出即可,不需要加 .index 后缀
正确执行后,你会看到:
image.png
对于 tensorflow2 的Bert模型,还有对应的转换接口:
import transformers.convert_bert_original_tf2_checkpoint_to_pytorch as con
con.convert_tf2_checkpoint_to_pytorch(……………………)
附上官方文档地址:https://huggingface.co/transformers/converting_tensorflow_models.html
网友评论