简介
通过深度学习技术搭建残差网络,使用CompsCars数据集 进行车型识别模型的训练,并将训练好的模型移植到了Android端,实现了通过手机扫一扫的方式进行汽车车型识别的功能。
项目涉及到的技术点较多,需要开发者有一定的技术功底。如:python语言的使用、深度学习框架pytorch的使用、爬虫脚本的理解、Java语言的使用、Android平台架构的理解等等。
虽然属于跨语言开发,但是要求并不高,只要达到入门级别即可看懂本项目,并可以尝试一些定制化的改造。毕竟框架已经搭建好了,只需要修改数据源、重新训练出模型,就可以实现一款新的应用啦。
最终效果
以下视频将展示所有功能完成后的APP的使用情况。
https://www.bilibili.com/video/BV1Pk4y1B7qK
模型训练精度
以下是使用Resnet-34进行400次车型识别训练的 train-validation图表。
loss.png以下是使用Resnet-34进行400次车型识别训练 Top-1的错误率。
top1.png以下是使用Resnet-34进行400次车型识别训练 Top-5的错误率。
top5.png扫一扫识别功能
以下是移植到android平台后进行识别的结果展示图。
UtcItH.png
使用的技术&框架
- 开发语言:Python、Java
- 技术框架:pytorch、resnet-34、Android平台
- 可选借助平台:百度AI平台
- 项目构成:模型训练项目、爬虫项目、APP开发项目
软/硬件需求
机器要求
因为涉及到机器学习模型训练,所以你应该拥有一台用来训练模型的机器,且需要搭载支持CUDA的GPU(如:GeForce、GTX、Tesla等),显存大小,自然是越大越好。
本人项目环境:
- windows10 专业版;GeForce MAX150;独显 2G;1T硬盘
也就是说这是最低配了,你至少要和我同一配置。
开发工具
- Pycharm:用来训练模型、pyhton爬虫、模型移植脚本
- Android Studio:用来开发安卓APP
数据集
数据集是项目最重要的一部分,有了数据集才能开始训练
本项目使用的是 香港中文大学的CompCars细粒度汽车数据集。
CompCars数据集需要的同学可以私聊找我要网盘链接。
各模块介绍
模型训练
Github 地址:pytorch_train 欢迎 star/issue
UtyvKf.png训练模型主要分为五个模块:启动器、自定义数据加载器、网络模型、学习率/损失率调整以及训练可视化。
启动器是项目的入口,通过对启动器参数的设置,可以进行很多灵活的启动方式,下图为部分启动器参数设置。
UtciwD.png任何一个深度学习的模型训练都是离不开数据集的,根据多种多样的数据集,我们应该使用一个方式将数据集用一种通用的结构返回,方便网络模型的加载处理。
Utc9OK.png这里使用了残差网络Resnet-34,代码中还提供了Resnet-18、Resnet-50、Resnet-101以及Resnet-152。残差结构是通过一个快捷连接,极大的减少了参数数量,降低了内存使用。
以下为残差网络的基本结构和Resnet-34 部分网络结构图。
UtcPeO.png Utcn6P.png除了最开始看到的train-val图表、Top-、Top-5的error记录表以外,在训练过程中,使用进度条打印当前训练的进度、训练精度等信息。打印时机可以通过上边提到的 启动器 优雅地配置。
Utc3kQ.png以下为最终的项目包架构。
pytorch_train
|-- data -- 存放读取训练、校验、测试数据路径的txt
| |-- train.txt
| |-- val.txt
| |-- test.txt
|-- result -- 存放最终生成训练结果的目录
|-- util -- 模型移植工具
|-- clr.py -- 学习率
|-- dataset.py -- 自定义数据集
|-- flops_benchmark.py -- 统计每秒浮点运算次数
|-- logger.py -- 日志可视化
|-- mobile_net.py -- 网络模型之一 mobile_net2
|-- resnet.py -- 网络模型之一 Resnet系列
|-- run.py -- 具体执行训练、测试方法
|-- start.py -- 启动器
UtgkuV.png
数据抓取
Github 地址:crawer/dongchedi 欢迎 star/issue
UtyXxP.png最终获取的数据如下图:
Utc8Yj.png模型移植
Github 地址:pytorch_train/transfor
import os
import torch
import torchvision
model_pth = os.path.join("results", "2020-04-27_10-27-17", 'checkpoint.pth.tar')
# 将resnet34模型保存为Android可以调用的文件
mobile_pt = os.path.join("results", "2020-04-27_10-27-17", 'resnet34.pt')
num_class = 13
device = 'cpu' # 'cuda:0' # cpu
model = torchvision.models.resnet34(num_classes=num_class)
model = torch.nn.DataParallel(model, [0])
model.to(device=device)
checkpoint = torch.load(model_pth, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.eval() # 模型设为评估模式
# 1张3通道224*224的图片
input_tensor = torch.rand(1, 3, 224, 224) # 设定输入数据格式
traced_script_module = torch.jit.trace(model.module, input_tensor) # 模型转化
traced_script_module.save(mobile_pt) # 保存文件
安卓界面&数据走向
Github 地址:carIdentify 欢迎 star/issue
UtyO2t.png实现了以下功能:
- 调用摄像头权限自动申请
- 摄像头预览
- 读取pytorch训练模型
- 调用第三方接口,精准预测
最终界面展示:
UtgABT.png安卓项目结构如图:
UtgEHU.png使用方式
启动模型训练
启动前需要确保你已经有了本项目使用的数据集 CompCars
重新开始新的训练
python start.py --data_root "./data" --gpus 0,1,2 -w 2 -b 120 --num_class 13
- --data_root 数据集路径位置
- --gups 使用gpu训练的块数
- -w 为gpu加载自定义数据集的工作线程
- -b 用来gpu训练的 batch size是多少
- --num_class 分类类别数量
使用上次训练结果继续训练
python start.py --data_root "./data" --gpus 0,1,2 -w 2 -b 120 --num_class 13 --resume "results/2020-04-14_12-36-16"
- --data_root 数据集路径位置
- --gups 使用gpu训练的块数
- -w 为gpu加载自定义数据集的工作线程
- -b 用来gpu训练的 batch size是多少
- --num_class 分类类别数量
- --resume 上次训练结果文件夹,可继续上次的训练
模型移植
将训练好的模型转换为Android可以执行的模型
python transfor.py
项目定制化
- 找寻自己的数据集
- 需要修改启动脚本中 --num_class,模型类别
目前项目中具备很多备注记录,稍加review代码就可以理解,如有不清楚,可以私信询问。
启动APP
APP下载链接:https://pan.baidu.com/s/1X7tobj4R302WmGu116-2mg 提取码: 1606
- 安装完成后
- 同意调用系统相机权限
- 使用扫一扫对准汽车
- 稍后将会展示识别后的结果和识别的图片
具体使用方式,可参见:https://www.bilibili.com/video/BV1Pk4y1B7qK
网友评论