在模型实际的应用中,一般有两种使用方法,一个是跑批数据,就像我们之前跑验证集那样。比如说我们收集到了很多需要去分类的图像,然后一次性的导入并使用我们训练好的模型给出结果,预测完这一批之后程序就自动关闭了,等到下一次我们有需要的时候再启动。另外一种就是应用于线上服务,构建一个服务等待新的请求,当有请求发起的时候就接收数据,然后给出结果,在没有请求的时候,模型服务仍然处于运行的状态,只不过是等待下一个请求。
Flask框架
关于一次性处理批数据,我们前面的流程基本可以满足了,这里介绍一个在线实时服务。FLask框架是一个用Python编写的Web微服务框架,Flask的使用十分简单,在日常开发中可以快速地实现一个Web服务,而且灵活度很高。
首先安装Flask。
pip install Flask
等待安装完之后,就可以编写代码了,假设我们写一个python脚本名字是flask_hello_world.py,内容如下
from flask import Flask
app = Flask(__name__)
@app.route("/hello")
def hello():
return "Hello World!"
if __name__ == '__main__':
app.run()
然后在shell里面运行它,这里我们在run方法里面没有设置参数,就会使用默认的127.0.0.1 host地址和5000端口,启动成功可以看到下面的显示
image.png
这个时候在浏览器中打开它,输入127.0.0.1:5000/hello,即可看到输出的结果“Hello World!”,这就完成了一个最简单的web服务。
image.png
如果要让它实现模型运算,重点就是去修改hello方法。
import numpy as np
import sys
import os
import torch
from flask import Flask, request, jsonify
import json
from p2ch13.model_cls import LunaModel
app = Flask(__name__)
#加载模型
model = LunaModel()
model.load_state_dict(torch.load(sys.argv[1],
map_location='cpu')['model_state'])
model.eval()
#运行推理部分
def run_inference(in_tensor):
with torch.no_grad():
# LunaModel 接收批量数据并输出一个元组 (scores, probs)
out_tensor = model(in_tensor.unsqueeze(0))[1].squeeze(0)
probs = out_tensor.tolist()
out = {'prob_malignant': probs[1]}
return out
@app.route("/predict", methods=["POST"])
#预测方法的逻辑
def predict():
#使用request接收数据
meta = json.load(request.files['meta'])
blob = request.files['blob'].read()
#转换成tensor
in_tensor = torch.from_numpy(np.frombuffer(
blob, dtype=np.float32))
in_tensor = in_tensor.view(*meta['shape'])
#推理,输出
out = run_inference(in_tensor)
#返回结果
return jsonify(out)
if __name__ == '__main__':
app.run()
print (sys.argv[1])
这样就已经写好了最简单的服务代码,然后运行它
image.png
这时候我们就已经启动了web服务,当然我们这里处理的比较简单,在真实场景下通常都是后台运行,并且要增加日志输出和报警系统,防止出现各种问题而服务中断。然后模拟客户端向服务端发送请求,很快就得到了结果,当然这里有一份预先准备好的数据,不然光数据处理就要花好多时间。
image.png
可以看到恶性肿瘤的可能性不大。到这里,我们就完成了一个简单的模型部署流程,当然,这里只是一个单一的服务,如果我们在工作中需要用到并发服务,异步服务可以在这个基础上进行修改,或者搭配其他的工具。比如说要实现并发服务,我们可以在服务器上启动多个服务,然后搭配Nginx实现负载均衡。
Sanic框架
然后我们再来介绍一个异步处理框架Sanic。现在是一个高并发的时代,并发量是在构建服务时必须考量的一个指标。所以我们自然就想到了 Python 中的异步框架,Sanic 的表现十分出色,使用 Sanic 构建的应用程序足以比肩 Nodejs。如果你再对 Sanic 在路由处理方面使用 C 语言做一些重构,那么并发性能可以和 Go 相媲美。
image.png
异步并发的流程大概像上图描述的样子,多个客户端发起请求,这些请求会进入一个任务队列,然后这些任务的数据组成一个批数据传给模型,模型给出预测结果,然后由请求处理器拆分结果并分别回传给不同的客户端。使用这种方式有助于提高我们的模型工作效率。
首先安装Sanic。
pip install sanic
接下来就是使用sanic完成一个异步服务。我们这里使用的是把马变成斑马的模型。来看看代码,首先是一些引用项。
import sys
import asyncio
import itertools
import functools
from sanic import Sanic
from sanic.response import json, text
from sanic.log import logger
from sanic.exceptions import ServerError
import sanic
import threading
import PIL.Image
import io
import torch
import torchvision
from .cyclegan import get_pretrained_model
定义一些全局变量或者参数。
#实例sanic
app = Sanic(__name__)
#设置使用的设备为cpu
device = torch.device('cpu')
# we only run 1 inference run at any time (one could schedule between several runners if desired)
MAX_QUEUE_SIZE = 3 # 队列最大长度
MAX_BATCH_SIZE = 2 # 批数据的最大长度
MAX_WAIT = 1 # 最大等待时间
异常处理类
class HandlingError(Exception):
def __init__(self, msg, code=500):
super().__init__()
self.handling_code = code
self.handling_msg = msg
模型运行类
class ModelRunner:
def __init__(self, model_name):
#首先是模型运行的初始化
self.model_name = model_name
#声明使用的队列
self.queue = []
#声明队列锁
self.queue_lock = None
#加载模型
self.model = get_pretrained_model(self.model_name,
map_location=device)
#是否运行的标记
self.needs_processing = None
#是否使用计时器
self.needs_processing_timer = None
调度运行信号处理
def schedule_processing_if_needed(self):
#判断队列长度是否已经超过批大小
if len(self.queue) >= MAX_BATCH_SIZE:
logger.debug("next batch ready when processing a batch")
#如果队列长度够长,把运行标记设置为需要运行
self.needs_processing.set()
#否则判断,如果队列不为空,查看计时器
elif self.queue:
logger.debug("queue nonempty when processing a batch, setting next timer")
self.needs_processing_timer = app.loop.call_at(self.queue[0]["time"] + MAX_WAIT, self.needs_processing.set)
处理输入数据并判断是否需要运行
async def process_input(self, input):
our_task = {"done_event": asyncio.Event(loop=app.loop),
"input": input,
"time": app.loop.time()}
async with self.queue_lock:
if len(self.queue) >= MAX_QUEUE_SIZE:
raise HandlingError("I'm too busy", code=503)
self.queue.append(our_task)
logger.debug("enqueued task. new queue size {}".format(len(self.queue)))
self.schedule_processing_if_needed()
#等等处理完成
await our_task["done_event"].wait()
return our_task["output"]
运行模型
def run_model(self, batch):
return self.model(batch.to(device)).to('cpu')
async def model_runner(self):
self.queue_lock = asyncio.Lock(loop=app.loop)
self.needs_processing = asyncio.Event(loop=app.loop)
logger.info("started model runner for {}".format(self.model_name))
#while True 无限循环,程序会处于监听状态
while True:
#等待有任务来
await self.needs_processing.wait()
self.needs_processing.clear()
#清空计时器
if self.needs_processing_timer is not None:
self.needs_processing_timer.cancel()
self.needs_processing_timer = None
#处理队列都开启锁
async with self.queue_lock:
#如果队列不为空则设置最长等待时间
if self.queue:
longest_wait = app.loop.time() - self.queue[0]["time"]
else: # oops
longest_wait = None
#日志记录启动处理,队列大小,等待时间
logger.debug("launching processing. queue size: {}. longest wait: {}".format(len(self.queue), longest_wait))
#获取一个批次的数据
to_process = self.queue[:MAX_BATCH_SIZE]
#然后把这些数据从任务队列中删除
del self.queue[:len(to_process)]
self.schedule_processing_if_needed()
#生成批数据
batch = torch.stack([t["input"] for t in to_process], dim=0)
#在一个单独的线程中运行模型,然后返回结果
result = await app.loop.run_in_executor(
None, functools.partial(self.run_model, batch)
)
#记录结果并设置一个完成事件
for t, r in zip(to_process, result):
t["output"] = r
t["done_event"].set()
del to_process
类实例化
style_transfer_runner = ModelRunner(sys.argv[1])
最后是处理网络交互
#路由策略
@app.route('/image', methods=['PUT'], stream=True)
#处理请求
async def image(request):
try:
#输出报头
print (request.headers)
content_length = int(request.headers.get('content-length', '0'))
#定义接收数据最大值
MAX_SIZE = 2**22 # 10MB
#如果接收数据超标返回异常信息
if content_length:
if content_length > MAX_SIZE:
raise HandlingError("Too large")
#初始化数据接收
data = bytearray(content_length)
else:
data = bytearray(MAX_SIZE)
pos = 0
#这里也是True,一直处于监听状态
while True:
#读取数据包
data_part = await request.stream.read()
if data_part is None:
break
#数据包拼接到data里面
data[pos: len(data_part) + pos] = data_part
pos += len(data_part)
if pos > MAX_SIZE:
raise HandlingError("Too large")
#然后开始对接收的图像数据进行预处理
im = PIL.Image.open(io.BytesIO(data))
im = torchvision.transforms.functional.resize(im, (228, 228))
im = torchvision.transforms.functional.to_tensor(im)
im = im[:3] # drop alpha channel if present
if im.dim() != 3 or im.size(0) < 3 or im.size(0) > 4:
raise HandlingError("need rgb image")
#使用实例化的模型程序处理图像
out_im = await style_transfer_runner.process_input(im)
#结果转化为图像信息
out_im = torchvision.transforms.functional.to_pil_image(out_im)
imgByteArr = io.BytesIO()
out_im.save(imgByteArr, format='JPEG')
return sanic.response.raw(imgByteArr.getvalue(), status=200,
content_type='image/jpeg')
except HandlingError as e:
# we don't want these to be logged...
return sanic.response.text(e.handling_msg, status=e.handling_code)
启动服务部分
app.add_task(style_transfer_runner.model_runner())
app.run(host="0.0.0.0", port=8000,debug=True)
看完代码,我们把它启动起来。
image.png
使用curl把图像数据传到web服务中,并设定了输出结果到res1.jpg中
image.png
去对应的位置查看,果然新生成了一张图片,可见我们的服务运行良好。
image.png
当然这里弄的两个实现方案都挺简单的,不过核心部分基本都介绍到了,在实际的工作中就是在这个基础上修修补补敲敲打打差不多就可以满足需求。
历时一个半月,终于把这本书看完了,英文原版写的挺好,由浅入深,但是这个翻译实在是有点烂,有需要英文原版电子书的留下邮箱。
网友评论