pytorch onnx onnxruntime tensorr

作者: qizhen816 | 来源:发表于2020-01-06 19:20 被阅读0次

    做了一个小测试,发现pytorch onnx tensorrt三个库的版本存在微妙的联系,在我之前的错误实验中,PyTorch==1.3.0/1.4.0;Onnx==1.6.0;tensorrt=7.0,用以下包含一个上采样层的代码做测试:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import os
    class TestModel(nn.Module):
        def __init__(self):
            super(TestModel, self).__init__()
        def forward(self, x):
            x = F.interpolate(x, (256, 256), mode = 'bilinear')
            return x
    torch_model = TestModel()
    dummy_input = torch.randn((1, 3, 256, 256))
    torch_out = torch.onnx.export(torch_model,
                                  dummy_input,
                                  'test_model.onnx',
                                  verbose=True,
                                  opset_version=11,)
    

    得到的Onnx模型:

    %2 : Tensor = onnx::Constant[value=[ CPUFloatType{0} ]]()
    %3 : Tensor = onnx::Constant[value= 1  3 [ CPULongType{2} ]]()
    %4 : Tensor = onnx::Cast[to=7](%1)
    %5 : Tensor = onnx::Concat[axis=0](%3, %4)
    
    1. 遇到的第一个错误,使用onnx.checker.check_model(onnx_model),
      Segmentation fault (core dumped)
      解决:在import torch之前import onnx,二者的前后顺序要注意
    2. onnxruntime的图片测试Unexpected input data type. Actual: (N11onnxruntime17PrimitiveDataTypeIdEE) , expected: (N11onnxruntime17PrimitiveDataTypeIfEE)
      解决:是传入Onnx模型的数据类型不对,换成np.float32试试。
            #一个语义分割网络onnx测试
            import onnx
            import onnxruntime
            import cv2
            img = cv2.imdecode(np.fromfile('test.jpg',dtype=np.uint8),-1)
            img = cv2.resize(img, (768,768))
            img = np.expand_dims(img,axis=0).astype(np.float32)/255
            img = img.transpose(0,3,1,2) #格式 Batch, Chanel, Height, Width
            ort_session = onnxruntime.InferenceSession('test.onnx')
            ort_inputs = {ort_session.get_inputs()[0].name: (image),} #类似tensorflow的传入数据,有几个输入就写几个
            ort_outs = ort_session.run(None, ort_inputs)
            mask = np.argmax(ort_outs[0], 1).squeeze().astype(np.int8)
            cv2.imwrite("result.jpg",mask*255)
    
    1. 不论是用编译好的onnx2trt,还是TensorRT直接读取,都会报一个错:
    pytorch 1.4.0
    In node 5 (parseGraph): INVALID_GRAPH: Assertion failed: ctx->tensors().count(inputName)
    

    大致就是5号节点的输入计数不正确,存在一些没有输入的叶子结点,用netron读取显示为:

    Onnx结构
    很明显,这个Constant就是多余的输入节点。
    解决:目前没有好的解决办法 设置opset_version=10,使用nearest上采样可以运行
    更新:在https://github.com/NVIDIA/TensorRT/issues/284,开发者回复说 TensorRT only supports assymetric resizing at the moment,也就是说nearest是可以用的,但是bilinear上采样还没有得到TensorRT的支持。

    相关文章

      网友评论

        本文标题:pytorch onnx onnxruntime tensorr

        本文链接:https://www.haomeiwen.com/subject/fsakactx.html