1. 环境配置
pytorch
onnxruntime==1.2.0 (1.3.0版本会报错ImportError: cannot import name 'get_all_providers')
onnxruntime-gpu==1.2.0
cuda10.1+cudnn7.6
2. 模型准备和转换
用torch.save()存储模型结构和权重
model = torch.load('pix2pix.pth', map_location=torch.device('cuda'))
单卡训练的模型
torch.onnx._export(model, dummy_input, "pix2pix.onnx", verbose=True, opset_version=11)
多卡训练的模型
torch.onnx._export(model, dummy_input, "pix2pix.onnx", verbose=True, opset_version=11)
3. 验证是否有精度损失
import onnxruntime
import numpy as np
from onnxruntime.datasets import get_example
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# 得到torch模型的输出
dummy_input = torch.randn(1, 3, 256, 256, device='cuda')
model.eval()
with torch.no_grad():
torch_out = model(dummy_input)
print(torch_out)
# 得到onnx模型的输出
example_model = get_example('D:/workspace_python/model_utils/pix2pix.onnx') #一定要写绝对路径
sess = onnxruntime.InferenceSession(example_model)
onnx_out = sess.run(None, {input_name: to_numpy(dummy_input)})
# 判断输出结果是否一致,小数点后3位一致即可
np.testing.assert_almost_equal(to_numpy(torch_out), onnx_out[0], decimal=3)
网友评论