美文网首页
Pytorch: 转onnx及精度验证

Pytorch: 转onnx及精度验证

作者: wzNote | 来源:发表于2020-06-22 11:30 被阅读0次

    1. 环境配置

    pytorch
    onnxruntime==1.2.0 (1.3.0版本会报错ImportError: cannot import name 'get_all_providers')
    onnxruntime-gpu==1.2.0
    cuda10.1+cudnn7.6

    2. 模型准备和转换

    用torch.save()存储模型结构和权重

    model = torch.load('pix2pix.pth', map_location=torch.device('cuda'))
    

    单卡训练的模型

    torch.onnx._export(model, dummy_input, "pix2pix.onnx", verbose=True, opset_version=11)
    

    多卡训练的模型

    torch.onnx._export(model, dummy_input, "pix2pix.onnx", verbose=True, opset_version=11)
    

    3. 验证是否有精度损失

    import onnxruntime
    import numpy as np
    from onnxruntime.datasets import get_example
    
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    
    # 得到torch模型的输出
    dummy_input = torch.randn(1, 3, 256, 256, device='cuda')
    model.eval()
    with torch.no_grad():
        torch_out = model(dummy_input)
    print(torch_out)
    
    # 得到onnx模型的输出
    example_model = get_example('D:/workspace_python/model_utils/pix2pix.onnx') #一定要写绝对路径
    sess = onnxruntime.InferenceSession(example_model)
    onnx_out = sess.run(None, {input_name: to_numpy(dummy_input)})
    
    # 判断输出结果是否一致,小数点后3位一致即可
    np.testing.assert_almost_equal(to_numpy(torch_out), onnx_out[0], decimal=3)
    

    相关文章

      网友评论

          本文标题:Pytorch: 转onnx及精度验证

          本文链接:https://www.haomeiwen.com/subject/tgkxfktx.html