美文网首页
将TF-checkpoint 文件转换为 pytorch-che

将TF-checkpoint 文件转换为 pytorch-che

作者: BoringFantasy | 来源:发表于2020-09-23 20:38 被阅读0次
    1. 改代码将Bert的Tensorflow 检查点转换为 Pytorch的检查点,整理Transformers的代码得到,为了方便使用同时记录踩的坑。

    2. Tensorflow检查点文件解析。

    1. 包括以下3个文件
    model.ckpt.data-00000-of-00001
    model.ckpt.index
    model.ckpt.meta
    2. 其中model.ckpt为checkpoint的文件前缀,在命令行调用该代码提供 --tf_checkpoint_path 时需要同时提供checkpoint 前缀,例如 --tf_checkpoint_path model_checkpoint/model.ckpt
    
    1. 同时提供模型Config文件,名字通常为bert_config.json。

    2. 调用该代码命令行为:

    # 依赖自行下载
    # $checkpoint_path 为TF-checkpoint路径
    # $save_file 为pytorch-checkpoint 保存文件
    python3 convert_bert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path $checkpoint_path/model.ckpt --bert_config_file $checkpoint_path/bert_config.json --pytorch_dump_path $save_file
    
    1. 保存后得到一个 pytorch-checkpoint, 需要同 bert_config.json 和 vocab.txt在同一个文件夹,同时需要将Bert_config.json增加一个命名为config.json的文件,Transformers加载Pytorch模型时会自动调用,之后可以通过Transformers正常使用。

    2. 目前该代码已经保存至 https://github.com/YaoXinZhi/Convert-Bert-TF-checkpoint-to-Pytorch

    相关文章

      网友评论

          本文标题:将TF-checkpoint 文件转换为 pytorch-che

          本文链接:https://www.haomeiwen.com/subject/mfzoyktx.html