美文网首页
Bert 转Onnx 模型时候遇到的坑

Bert 转Onnx 模型时候遇到的坑

作者: NazgulSun | 来源:发表于2023-03-11 15:55 被阅读0次

Bert 转Onnx 背景

自己使用的是预训练模型作的finetuning。
然后保存了pt 文件,
失败的方法:
直接 model = torch.load(file.pt)
torch.onnx.export(
model,
(b_input_ids,token_types,b_input_mask),
"model.onnx",
input_names=['input_ids','token_type_ids', 'attention_mask'],
output_names=["logits"],
dynamic_axes= {
'input_ids': {0: 'batch_size'},
'token_type_ids': {0: 'batch_size'},
'attention_mask': {0: 'batch_size'},
'logits': {0: 'batch_size'}
},
do_constant_folding=True,
opset_version= 11)
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model,True)

整个方法没有问题,然后就是infer的时候, 结果和pytorch 结果不一样;
查了整个网上的issue 回答,没有明确的答案,只是判断为 export 导出的结果和原始的结构不一样。

使用 transformer 官方的方法

可参考
https://levelup.gitconnected.com/multilingual-text-classification-with-transformers-2147fe179c6b
先用checkpoint的方式保存 model 和 tokenizer
model = torch.load(model_path)
model.save_pretrained("")
tokenizer = BertTokenizerFast.from_pretrained("swtx/ernie-3.0-base-chinese")
tokenizer.save_pretrained(
")

然后使用观法的命令:
python -m transformers.onnx --model=hard-v6/ --feature=sequence-classification --framework=pt onnx_output/
大家可以参考 transformers.onnx 文档,
--feature 与你使用的模型相关,需要查找文档确定,由于我是句子分类任务所以用的这个。
用这个方法传出来的onnx模型,infer的结果就是正确的。

相关文章

网友评论

      本文标题:Bert 转Onnx 模型时候遇到的坑

      本文链接:https://www.haomeiwen.com/subject/gupxrdtx.html