美文网首页
pytorch的模型部署入门

pytorch的模型部署入门

作者: 万州客 | 来源:发表于2022-06-27 18:16 被阅读0次

今天花了很多时间来了解的模型部署,作个代码记录。

一,代码

import os
import cv2
import numpy as np
import requests
import torch
import onnx
import torch.onnx
from torch import nn
import onnxruntime

"""
class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False
        )
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

        self.relu = nn.ReLU()

    def forward(self,x):
        x = self.img_upsampler(x)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out



# 为了方便起见,我们跳过训练网络的步骤,直接下载模型权重
# (由于 MMEditing 中 SRCNN 的权重结构和我们定义的模型不太一样,我们修改了权重字典的 key 来适配我们定义的模型),
# 同时下载好输入图片。
# Download checkpoint and test image
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth',
        'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']

for url, name in zip(urls, names):
    if not os.path.exists(name):
        open(name, 'wb').write(requests.get(url).content)

def init_touch_model():
    torch_model = SuperResolutionNet(upscale_factor=3)
    state_dict = torch.load('srcnn.pth')['state_dict']
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)
    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model


model = init_touch_model()

# 为了让模型输出成正确的图片格式,我们把模型的输出转换成 HWC 格式,
# 并保证每一通道的颜色值都在 0~255 之间。
input_img = cv2.imread('bird.jpg').astype(np.float32)
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

torch_output = model(torch.from_numpy(input_img)).detach().numpy()

torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)

cv2.imwrite('face_torch.jpg', torch_output)

# 把 PyTorch 的模型转换成 ONNX 格式的模型
# 其中,torch.onnx.export 是 PyTorch 自带的把模型转换成 ONNX 格式的函数。
# 让我们先看一下前三个必选参数:前三个参数分别是要转换的模型、模型的任意一组输入、导出的 ONNX 文件的文件名。
# 转换模型时,需要原模型和输出文件名是很容易理解的,但为什么需要为模型提供一组输入呢?
# 这就涉及到 ONNX 转换的原理了。
# 从 PyTorch 的模型到 ONNX 的模型,本质上是一种语言上的翻译。
# 直觉上的想法是像编译器一样彻底解析原模型的代码,记录所有控制流。
# 但前面也讲到,我们通常只用 ONNX 记录不考虑控制流的静态图。
# 因此,PyTorch 提供了一种叫做追踪(trace)的模型转换方法:
# 给定一组输入,再实际执行一遍模型,即把这组输入对应的计算图记录下来,保存为 ONNX 格式。
# export 函数用的就是追踪导出方法,需要给任意一组输入,让模型跑起来。
# 我们的测试图片是三通道,256x256大小的,这里也构造一个同样形状的随机张量。
# 剩下的参数中,opset_version 表示 ONNX 算子集的版本。
# 深度学习的发展会不断诞生新算子,为了支持这些新增的算子,ONNX会经常发布新的算子集,目前已经更新15个版本。
# 我们令 opset_version = 11,即使用第11个 ONNX 算子集,是因为 SRCNN 中的 bicubic (双三次插值)在 opset11 中才得到支持。
# 剩下的两个参数 input_names, output_names 是输入、输出 tensor 的名称,我们稍后会用到这些名称。

x = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        'srcnn.onnx',
        opset_version=11,
        input_names=['input'],
        output_names=['output']
    )

onnx_model = onnx.load('srcnn.onnx')
try:
    onnx.checker.check_model(onnx_model)
except Exception as e:
    print('Model incorrect.')
else:
    print('Model correct.')
"""
input_img = cv2.imread('bird.jpg').astype(np.float32)
input_img = np.transpose(input_img, [2, 1, 0])
input_img = np.expand_dims(input_img, 0)

ort_session = onnxruntime.InferenceSession('srcnn.onnx')
ort_inputs = {'input': input_img}
ort_output = ort_session.run(['output'], ort_inputs)[0]


ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite('bird_org.jpg', ort_output)

二,效果

2022-06-27 17_26_41-悬浮球.png

相关文章

网友评论

      本文标题:pytorch的模型部署入门

      本文链接:https://www.haomeiwen.com/subject/pjlovrtx.html