pytorch onnx onnxruntime tensorr

作者: qizhen816 | 来源:发表于2020-01-06 19:20 被阅读0次

做了一个小测试,发现pytorch onnx tensorrt三个库的版本存在微妙的联系,在我之前的错误实验中,PyTorch==1.3.0/1.4.0;Onnx==1.6.0;tensorrt=7.0,用以下包含一个上采样层的代码做测试:

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
class TestModel(nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()
    def forward(self, x):
        x = F.interpolate(x, (256, 256), mode = 'bilinear')
        return x
torch_model = TestModel()
dummy_input = torch.randn((1, 3, 256, 256))
torch_out = torch.onnx.export(torch_model,
                              dummy_input,
                              'test_model.onnx',
                              verbose=True,
                              opset_version=11,)

得到的Onnx模型:

%2 : Tensor = onnx::Constant[value=[ CPUFloatType{0} ]]()
%3 : Tensor = onnx::Constant[value= 1  3 [ CPULongType{2} ]]()
%4 : Tensor = onnx::Cast[to=7](%1)
%5 : Tensor = onnx::Concat[axis=0](%3, %4)
  1. 遇到的第一个错误,使用onnx.checker.check_model(onnx_model),
    Segmentation fault (core dumped)
    解决:在import torch之前import onnx,二者的前后顺序要注意
  2. onnxruntime的图片测试Unexpected input data type. Actual: (N11onnxruntime17PrimitiveDataTypeIdEE) , expected: (N11onnxruntime17PrimitiveDataTypeIfEE)
    解决:是传入Onnx模型的数据类型不对,换成np.float32试试。
        #一个语义分割网络onnx测试
        import onnx
        import onnxruntime
        import cv2
        img = cv2.imdecode(np.fromfile('test.jpg',dtype=np.uint8),-1)
        img = cv2.resize(img, (768,768))
        img = np.expand_dims(img,axis=0).astype(np.float32)/255
        img = img.transpose(0,3,1,2) #格式 Batch, Chanel, Height, Width
        ort_session = onnxruntime.InferenceSession('test.onnx')
        ort_inputs = {ort_session.get_inputs()[0].name: (image),} #类似tensorflow的传入数据,有几个输入就写几个
        ort_outs = ort_session.run(None, ort_inputs)
        mask = np.argmax(ort_outs[0], 1).squeeze().astype(np.int8)
        cv2.imwrite("result.jpg",mask*255)
  1. 不论是用编译好的onnx2trt,还是TensorRT直接读取,都会报一个错:
pytorch 1.4.0
In node 5 (parseGraph): INVALID_GRAPH: Assertion failed: ctx->tensors().count(inputName)

大致就是5号节点的输入计数不正确,存在一些没有输入的叶子结点,用netron读取显示为:

Onnx结构
很明显,这个Constant就是多余的输入节点。
解决:目前没有好的解决办法 设置opset_version=10,使用nearest上采样可以运行
更新:在https://github.com/NVIDIA/TensorRT/issues/284,开发者回复说 TensorRT only supports assymetric resizing at the moment,也就是说nearest是可以用的,但是bilinear上采样还没有得到TensorRT的支持。

相关文章

网友评论

    本文标题:pytorch onnx onnxruntime tensorr

    本文链接:https://www.haomeiwen.com/subject/fsakactx.html