TensorRT Python Validation Code

Author: 教训小磊 | Published 2023-08-20 15:52
    
    import tensorrt as trt
    import numpy as np
    import os
    import cv2
    
    import pycuda.driver as cuda
    import pycuda.autoinit
    
    
    class HostDeviceMem(object):
        def __init__(self, host_mem, device_mem):
            self.host = host_mem
            self.device = device_mem
    
        def __str__(self):
            return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
    
        def __repr__(self):
            return self.__str__()
    
    
    class TrtModel:
    
        def __init__(self, engine_path, max_batch_size=1, dtype=np.float32):
    
            self.engine_path = engine_path
            self.dtype = dtype
            self.logger = trt.Logger(trt.Logger.WARNING)
            self.runtime = trt.Runtime(self.logger)
            self.engine = self.load_engine(self.runtime, self.engine_path)
            self.max_batch_size = max_batch_size
            self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()
            self.context = self.engine.create_execution_context()
    
        @staticmethod
        def load_engine(trt_runtime, engine_path):
            trt.init_libnvinfer_plugins(None, "")
            with open(engine_path, 'rb') as f:
                engine_data = f.read()
            engine = trt_runtime.deserialize_cuda_engine(engine_data)
            return engine
    
        def allocate_buffers(self):
    
            inputs = []
            outputs = []
            bindings = []
            stream = cuda.Stream()
    
            for binding in self.engine:
                # Allocate one host/device buffer pair per binding, sized for max_batch_size.
                size = trt.volume(self.engine.get_binding_shape(binding)) * self.max_batch_size
                host_mem = cuda.pagelocked_empty(size, self.dtype)  # page-locked host memory for async copies
                device_mem = cuda.mem_alloc(host_mem.nbytes)

                bindings.append(int(device_mem))
    
                if self.engine.binding_is_input(binding):
                    inputs.append(HostDeviceMem(host_mem, device_mem))
                else:
                    outputs.append(HostDeviceMem(host_mem, device_mem))
    
            return inputs, outputs, bindings, stream
    
        def __call__(self, x: np.ndarray, batch_size=1):
    
            x = x.astype(self.dtype)
    
            np.copyto(self.inputs[0].host, x.ravel())
    
            for inp in self.inputs:
                cuda.memcpy_htod_async(inp.device, inp.host, self.stream)
    
            # Legacy implicit-batch call; explicit-batch engines need execute_async_v2(bindings, stream_handle) instead.
            self.context.execute_async(batch_size=batch_size, bindings=self.bindings, stream_handle=self.stream.handle)
            for out in self.outputs:
                cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
    
            self.stream.synchronize()
            return [out.host.reshape(batch_size, -1) for out in self.outputs]
    
    
    if __name__ == "__main__":
    
        trt_engine_path = r'./trt/cls-smi.engine'
        pic_path = r'./trt/11.bmp'
        w, h = 112, 112
    
    
        mean = (127.5, 127.5, 127.5)
        std = (127.5, 127.5, 127.5)
    
        labels_cls = {0: 'background',
                      1: 'QPZZ',
                      2: 'MDBD',
                      3: 'MNYW',
                      4: 'WW',
                      5: 'LMPS',
                      6: 'BMQQ',
                      7: 'LMHH',
                      8: 'KTAK',
                      }
    
        # Input image preprocessing
        img = cv2.imread(pic_path)
        img = cv2.resize(img, (w, h))
        img = img[:, :, ::-1]  # BGR -> RGB
        img = np.array(img).astype(np.float32)  # note: input dtype must be np.float32
        img -= mean
        img /= std
        img = np.array([np.transpose(img, (2, 0, 1))])  # HWC -> CHW, add batch dimension
    
        # Model inference
        model = TrtModel(trt_engine_path)
        # shape = model.engine.get_binding_shape(0)
        result = model(img, 1)
    
        score = result[0][0][0]
        index = int(result[1][0][0])
        class_name = labels_cls[index]
        print('{}:{:.4f}'.format(class_name, score))
    
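The script assumes that the serialized engine ./trt/cls-smi.engine already exists. If you only have an ONNX export of the classifier, one common way to produce the engine is TensorRT's trtexec tool, e.g. `trtexec --onnx=cls-smi.onnx --saveEngine=./trt/cls-smi.engine` (cls-smi.onnx is a placeholder for your own export).

Also note that TrtModel relies on the legacy binding-index / implicit-batch API (get_binding_shape, binding_is_input, execute_async), which newer TensorRT releases have deprecated or removed in favor of tensor names. As a reference only, here is a minimal sketch of the same load / allocate / infer flow with the tensor-name API; it assumes TensorRT 8.5 or newer, a static-shape engine with a single input, and uses a dummy array in place of the preprocessed image built above:

    import numpy as np
    import tensorrt as trt
    import pycuda.driver as cuda
    import pycuda.autoinit

    logger = trt.Logger(trt.Logger.WARNING)
    trt.init_libnvinfer_plugins(logger, "")
    with open('./trt/cls-smi.engine', 'rb') as f:
        engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()
    stream = cuda.Stream()

    # Discover I/O tensors by name and allocate one host/device buffer pair for each.
    host, device, output_names = {}, {}, []
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        shape = engine.get_tensor_shape(name)             # static-shape engine assumed
        dtype = trt.nptype(engine.get_tensor_dtype(name))
        host[name] = cuda.pagelocked_empty(trt.volume(shape), dtype)
        device[name] = cuda.mem_alloc(host[name].nbytes)
        context.set_tensor_address(name, int(device[name]))
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_name = name
        else:
            output_names.append(name)

    # Dummy input standing in for the preprocessed 1x3x112x112 image from the script above.
    img = np.zeros((1, 3, 112, 112), dtype=np.float32)
    np.copyto(host[input_name], img.ravel())
    cuda.memcpy_htod_async(device[input_name], host[input_name], stream)
    context.execute_async_v3(stream.handle)
    for name in output_names:
        cuda.memcpy_dtoh_async(host[name], device[name], stream)
    stream.synchronize()
    print({name: host[name] for name in output_names})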
