简单实现YOLOv3 WebAPI
YOLOv3训练自己的数据
YOLOv3安装
先整体看一下文件目录
.
├── darknet.py              # 调用darknet框架的脚本
├── libdarknet.so           # 编译darknet(yolov3)环境时生成的,位于编译目录的根目录,我这里复制过来
├── own
│   ├── pic_box             # 结果图片
│   │   └── 20208179501115.jpg
│   ├── pic_save            # 上传的图片
│   │   ├── 2020817945515.jpg
│   │   └── 20208179501115.jpg
│   ├── struct              # 配置文件
│   │   ├── voc.data
│   │   ├── yolov3.cfg
│   │   ├── yolov3.names
│   │   └── yolov3_test.cfg
│   └── weights             # 权重文件
│       └── yolov3_final.weights
├── __pycache__
│   └── darknet.cpython-37.pyc
├── server.py               # web服务
└── templates               # 页面
    ├── index.html
    └── result.html
使用FastAPI部署yolo3目标检测
server.py
# *_* coding: utf-8 *_*
# @File : server.py
# @Author: mihongguang
# @Date : 2020/8/15
'''
使用FastAPI部署yolo3目标检测:通过fastapi存储上传的图片,再通过yolo3识别,然后将识别后的图片返回
'''
import os
from fastapi import FastAPI, File, UploadFile
from fastapi.requests import Request
from fastapi.templating import Jinja2Templates
from fastapi.responses import FileResponse
import darknet as dn
import time
import cv2 as cv
# Create the FastAPI application instance
app = FastAPI()
# Directory that holds the front-end Jinja2 templates
template = Jinja2Templates(directory="templates")
# Select the GPU to use (index 0)
dn.set_gpu(0)
print('kuai di loading...')
# All model/config paths are built from the working directory at import time,
# so the server must be started from the project root.
cwd_path = os.getcwd()
cfg_path = cwd_path + "/own/struct/yolov3_test.cfg"
weights_path = cwd_path + "/own/weights/yolov3_final.weights"
data_path = cwd_path + "/own/struct/voc.data"
# Load the network config, weights and metadata once at import time; the
# paths are passed as UTF-8 bytes (presumably required by the darknet ctypes
# binding -- confirm against darknet.py).
net_kuaidi = dn.load_net(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)
meta_kuaidi = dn.load_meta(data_path.encode('utf-8'))
# Draw the detection box and its label on the image and save the result
def draw_anchor(img_file, point_data, save_file):
    """Draw one detection's box and class label on an image and save it.

    Args:
        img_file: Path of the image to read.
        point_data: Mapping with the keys 'x', 'y', 'w', 'h' and 'class'.
            'x'/'y' are treated as the box centre and 'w'/'h' as its size.
        save_file: Path the annotated image is written to.
    """
    image = cv.imread(img_file)

    centre_x, half_w = point_data['x'], point_data['w'] / 2
    centre_y, half_h = point_data['y'], point_data['h'] / 2

    # Two opposite corners of the box; the label is anchored at the first one.
    corner_a = (int(centre_x - half_w), int(centre_y + half_h))
    corner_b = (int(centre_x + half_w), int(centre_y - half_h))
    label = point_data['class']

    cv.rectangle(image, corner_a, corner_b, (0, 255, 0), thickness=2)
    cv.putText(image, label, corner_a, cv.FONT_HERSHEY_COMPLEX, 0.7,
               (255, 255, 0), thickness=2)
    cv.imwrite(save_file, image)
# Run darknet detection on an image file
def handle_pic(img_path, net_kuaidi, meta_kuaidi):
    """Run darknet detection on the image at ``img_path``.

    Args:
        img_path: Image path as ``str``; it is UTF-8 encoded before being
            handed to ``dn.detect``.
        net_kuaidi: Network object returned by ``dn.load_net``.
        meta_kuaidi: Metadata object returned by ``dn.load_meta``.

    Returns:
        Whatever ``dn.detect`` returns for this image.
    """
    encoded_path = img_path.encode('utf-8')
    return dn.detect(net_kuaidi, meta_kuaidi, encoded_path)
# Convert raw detection output into the response payload
def return_data(outputs):
    """Build the result dict from the first detection in ``outputs``.

    Only ``outputs[0]`` is considered; images containing several parcels are
    not handled. The element at index 2 of that detection must hold exactly
    four values, which are reported as x, y, w and h.

    Args:
        outputs: Sequence of detections; each detection is indexable and its
            item at index 2 is the box sequence.

    Returns:
        ``{'class': 'express', 'x': ..., 'y': ..., 'w': ..., 'h': ...}`` for
        a usable first detection, otherwise an empty dict.
    """
    print("outputs", outputs)
    if len(outputs) != 0:
        box = outputs[0][2]
        if len(box) == 4:
            return {
                'class': 'express',
                'x': box[0],
                'y': box[1],
                'w': box[2],
                'h': box[3],
            }
    # No detections, or a first detection without a 4-value box.
    return dict()
# Endpoint called by the page's submit button; returns rendered HTML
@app.post("/upload/")
async def create_upload_file(request: Request, file: UploadFile = File(...)):
    """Store an uploaded image, run detection on it and render the result page.

    The upload is written to ``own/pic_save``. If a detection is found, an
    annotated copy is written to ``own/pic_box``; otherwise the original
    bytes are copied there so the result page always has an image to show.

    Args:
        request: Incoming request, required by the template response.
        file: Uploaded image.

    Returns:
        The rendered ``result.html`` template with ``imgname`` (stored file
        name) and ``result`` (``{'kuaidi': <detection dict>}``).
    """
    content = await file.read()

    # Timestamp prefix: year, month, day, hour, minute, second, not zero-padded.
    name_time = "".join(str(part) for part in time.localtime()[:6])

    # The client controls the file name: keep only its final path component
    # (handling both '/' and '\\' separators) so it cannot point outside the
    # storage directories, and tolerate a missing name.
    client_name = (file.filename or "").replace("\\", "/")
    file_name = name_time + os.path.basename(client_name)

    # Where the uploaded image is stored
    dir_in = cwd_path + "/own/pic_save/" + file_name
    # Where the result image is stored
    dir_out = cwd_path + "/own/pic_box/" + file_name

    with open(dir_in, "wb") as f:
        f.write(content)

    outputs_kuaidi = handle_pic(dir_in, net_kuaidi, meta_kuaidi)
    data = {'kuaidi': return_data(outputs_kuaidi)}

    if data['kuaidi']:
        # A detection exists: save an annotated copy as the result image.
        draw_anchor(dir_in, data['kuaidi'], dir_out)
    else:
        # Nothing detected: the result image is the unmodified upload.
        with open(dir_out, "wb") as f:
            f.write(content)

    return template.TemplateResponse(
        "result.html",
        {"request": request, "imgname": file_name, "result": data},
    )
# JSON API: same detection as the /upload/ route, but returns JSON
@app.post("/request/")
async def create_upload_file(file: UploadFile = File(...)):
    """Store an uploaded image, run detection on it and return JSON.

    NOTE: this handler shares its function name with the ``/upload/``
    handler; both routes are registered when their decorators run, so both
    keep working.

    Args:
        file: Uploaded image.

    Returns:
        ``{"result": {"kuaidi": <detection dict>}}`` where the detection dict
        is empty when nothing usable was detected.
    """
    content = await file.read()

    # Timestamp prefix: year, month, day, hour, minute, second, not zero-padded.
    name_time = "".join(str(part) for part in time.localtime()[:6])

    # The client controls the file name: keep only its final path component
    # (handling both '/' and '\\' separators) so it cannot point outside the
    # storage directory, and tolerate a missing name.
    client_name = (file.filename or "").replace("\\", "/")
    file_name = name_time + os.path.basename(client_name)

    # Raw uploads belong in pic_save (as in the /upload/ route), not in
    # pic_box, which holds result images served by /inference/output/.
    dir_in = cwd_path + "/own/pic_save/" + file_name
    with open(dir_in, "wb") as f:
        f.write(content)

    outputs_kuaidi = handle_pic(dir_in, net_kuaidi, meta_kuaidi)
    data = {'kuaidi': return_data(outputs_kuaidi)}
    return {"result": data}
# Serve a result image to the result page
@app.get("/inference/output/{filename}")
async def get_img(filename: str):
    """Return a result image from ``own/pic_box``.

    Args:
        filename: Name of the result image, taken from the URL path.

    Returns:
        A ``FileResponse`` for the requested image.

    Raises:
        HTTPException: 404 when no regular file of that name exists in the
            result directory.
    """
    # Imported locally because the file's top-level imports do not include it.
    from fastapi import HTTPException

    # Only the final path component is used, so the request cannot address
    # anything outside the result directory.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    dir_out = cwd_path + "/own/pic_box/" + safe_name

    if not os.path.isfile(dir_out):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(dir_out)
# Index page
@app.get("/")
async def index(request: Request):
    """Render the ``index.html`` landing page."""
    context = {"request": request}
    return template.TemplateResponse("index.html", context)
if __name__ == '__main__':
    # uvicorn is the ASGI server used to run the app.
    import uvicorn

    # The former ``debug=True`` argument is omitted: uvicorn removed that
    # option, and passing it to current releases raises TypeError at startup.
    uvicorn.run(app, host="0.0.0.0", port=8888)
网友评论