今天花了很多时间来了解的模型部署,作个代码记录。
一,代码
import os
import cv2
import numpy as np
import requests
import torch
import onnx
import torch.onnx
from torch import nn
import onnxruntime
"""
class SuperResolutionNet(nn.Module):
def __init__(self, upscale_factor):
super().__init__()
self.upscale_factor = upscale_factor
self.img_upsampler = nn.Upsample(
scale_factor=self.upscale_factor,
mode='bicubic',
align_corners=False
)
self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
self.relu = nn.ReLU()
def forward(self,x):
x = self.img_upsampler(x)
out = self.relu(self.conv1(x))
out = self.relu(self.conv2(out))
out = self.conv3(out)
return out
# 为了方便起见,我们跳过训练网络的步骤,直接下载模型权重
# (由于 MMEditing 中 SRCNN 的权重结构和我们定义的模型不太一样,我们修改了权重字典的 key 来适配我们定义的模型),
# 同时下载好输入图片。
# Download checkpoint and test image
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth',
'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']
for url, name in zip(urls, names):
if not os.path.exists(name):
open(name, 'wb').write(requests.get(url).content)
def init_touch_model():
torch_model = SuperResolutionNet(upscale_factor=3)
state_dict = torch.load('srcnn.pth')['state_dict']
for old_key in list(state_dict.keys()):
new_key = '.'.join(old_key.split('.')[1:])
state_dict[new_key] = state_dict.pop(old_key)
torch_model.load_state_dict(state_dict)
torch_model.eval()
return torch_model
model = init_touch_model()
# 为了让模型输出成正确的图片格式,我们把模型的输出转换成 HWC 格式,
# 并保证每一通道的颜色值都在 0~255 之间。
input_img = cv2.imread('bird.jpg').astype(np.float32)
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)
torch_output = model(torch.from_numpy(input_img)).detach().numpy()
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite('face_torch.jpg', torch_output)
# 把 PyTorch 的模型转换成 ONNX 格式的模型
# 其中,torch.onnx.export 是 PyTorch 自带的把模型转换成 ONNX 格式的函数。
# 让我们先看一下前三个必选参数:前三个参数分别是要转换的模型、模型的任意一组输入、导出的 ONNX 文件的文件名。
# 转换模型时,需要原模型和输出文件名是很容易理解的,但为什么需要为模型提供一组输入呢?
# 这就涉及到 ONNX 转换的原理了。
# 从 PyTorch 的模型到 ONNX 的模型,本质上是一种语言上的翻译。
# 直觉上的想法是像编译器一样彻底解析原模型的代码,记录所有控制流。
# 但前面也讲到,我们通常只用 ONNX 记录不考虑控制流的静态图。
# 因此,PyTorch 提供了一种叫做追踪(trace)的模型转换方法:
# 给定一组输入,再实际执行一遍模型,即把这组输入对应的计算图记录下来,保存为 ONNX 格式。
# export 函数用的就是追踪导出方法,需要给任意一组输入,让模型跑起来。
# 我们的测试图片是三通道,256x256大小的,这里也构造一个同样形状的随机张量。
# 剩下的参数中,opset_version 表示 ONNX 算子集的版本。
# 深度学习的发展会不断诞生新算子,为了支持这些新增的算子,ONNX会经常发布新的算子集,目前已经更新15个版本。
# 我们令 opset_version = 11,即使用第11个 ONNX 算子集,是因为 SRCNN 中的 bicubic (双三次插值)在 opset11 中才得到支持。
# 剩下的两个参数 input_names, output_names 是输入、输出 tensor 的名称,我们稍后会用到这些名称。
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
torch.onnx.export(
model,
x,
'srcnn.onnx',
opset_version=11,
input_names=['input'],
output_names=['output']
)
onnx_model = onnx.load('srcnn.onnx')
try:
onnx.checker.check_model(onnx_model)
except Exception as e:
print('Model incorrect.')
else:
print('Model correct.')
"""
input_img = cv2.imread('bird.jpg').astype(np.float32)
input_img = np.transpose(input_img, [2, 1, 0])
input_img = np.expand_dims(input_img, 0)
ort_session = onnxruntime.InferenceSession('srcnn.onnx')
ort_inputs = {'input': input_img}
ort_output = ort_session.run(['output'], ort_inputs)[0]
ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite('bird_org.jpg', ort_output)
二,效果
![](https://img.haomeiwen.com/i23118846/5e3c31dd3094f190.png)
网友评论