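Recently I built a license plate recognition project for a course assignment, so I went to GitHub and found a paper from USTC (University of Science and Technology of China).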
Source code: https://github.com/detectRecog/CCPD (the repository also contains the paper)
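The training dataset is 15 GB, so I did not download it and train the model myself. Instead I used the pretrained weights provided by the authors. The results were not ideal, and several parts of the original code had to be modified.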
Modified code: https://github.com/872699467/CCPD_CNN
1. Modify roi_pooling.py
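The original roi_pooling.py contains this import:

```python
from torch._thnn import type2backend
```

The torch version I installed has no _thnn module. Delete this line and import the functional API instead:

```python
import torch.nn.functional as F
```

Then, in the pooling loop, replace the line that calls the hand-written adaptive max pooling with:

```python
output.append(F.adaptive_max_pool2d(im, size))
```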
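PyTorch already implements adaptive pooling (F.adaptive_max_pool2d, or the nn.AdaptiveMaxPool2d module). Interested readers can look it up. After this change, the custom AdaptiveMaxPool2d class and the adaptive_max_pool function defined in the original source are no longer used and can be deleted.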
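To show what the replaced line does in context, here is a minimal ROI max-pooling sketch built on F.adaptive_max_pool2d. It is not the repository's code. The function name roi_pooling, the ROI row layout [batch_index, x1, y1, x2, y2] with inclusive corners, and the spatial_scale handling are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def roi_pooling(features, rois, size, spatial_scale=1.0):
    """Crop each ROI from `features` and max-pool it to a fixed `size` (sketch).

    features: (N, C, H, W) feature map.
    rois: (R, 5) tensor; each row is assumed to be [batch_index, x1, y1, x2, y2]
    with inclusive corners in input-image coordinates (assumption for this example).
    """
    rois = rois.detach().float().clone()
    rois[:, 1:] *= spatial_scale  # map box corners onto the feature-map grid
    rois = rois.long()
    output = []
    for idx, x1, y1, x2, y2 in rois.tolist():
        region = features[idx:idx + 1, :, y1:y2 + 1, x1:x2 + 1]
        # The built-in adaptive pooling picks the bins itself, so every ROI,
        # whatever its size, comes out with exactly `size` spatial dimensions.
        output.append(F.adaptive_max_pool2d(region, size))
    return torch.cat(output, 0)
```

Regions of different sizes all end up with the same output shape. For example, two ROIs pooled with size=(4, 4) on a 3-channel map give a tensor of shape (2, 3, 4, 4). This fixed shape is why the hand-written pooling class becomes redundant.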
2. Modify demo.py
With the changes above, demo.py should run correctly. However, the recognized plate string contains no Chinese character.
(Figure: detection result)
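The code contains the following comment: the first character is a Chinese character that cannot be printed normally, so it is omitted.

```python
# The first character is Chinese character, can not be printed normally, thus is omitted.
```

So the code has to be changed to display the Chinese character. The root cause is that cv2.putText can only draw with OpenCV's built-in Hershey fonts, which cover ASCII characters only.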
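The decoding step shared by both versions of the loop can be sketched on its own. Each of the model's seven output heads yields one score per class. The predicted class is the argmax, the same as t[0].index(max(t[0])) in demo.py. The first index is looked up in provinces, the second in alphabets, and the remaining five in ads. The three tables below are short hypothetical stand-ins, not the real lists from demo.py, and decode_plate is a name invented for this example.

```python
# Toy lookup tables (hypothetical); the real lists in demo.py are much longer.
provinces = ["皖", "沪", "津"]
alphabets = ["A", "B", "C"]
ads = ["A", "B", "0", "1", "2"]


def decode_plate(scores):
    """Turn 7 per-character score lists into a plate string (sketch).

    scores[k] holds one score per class for character k of the plate.
    """
    # Argmax of each head, mirroring t[0].index(max(t[0])) in demo.py
    label_pred = [s.index(max(s)) for s in scores]
    # Character 0: Chinese province; character 1: letter; characters 2-6: letters/digits
    tables = [provinces, alphabets] + [ads] * 5
    return "".join(table[k] for table, k in zip(tables, label_pred))
```

The original demo skipped the provinces lookup for the first head. That omission is why it printed only six characters.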
Original code:
```python
for i, (XI, ims) in enumerate(trainloader):
    if use_gpu:
        x = Variable(XI.cuda(0))
    else:
        x = Variable(XI)
    # Forward pass: Compute predicted y by passing x to the model
    fps_pred, y_pred = model_conv(x)

    outputY = [el.data.cpu().numpy().tolist() for el in y_pred]
    labelPred = [t[0].index(max(t[0])) for t in outputY]

    [cx, cy, w, h] = fps_pred.data.cpu().numpy()[0].tolist()

    img = cv2.imread(ims[0])
    left_up = [(cx - w/2)*img.shape[1], (cy - h/2)*img.shape[0]]
    right_down = [(cx + w/2)*img.shape[1], (cy + h/2)*img.shape[0]]
    cv2.rectangle(img, (int(left_up[0]), int(left_up[1])), (int(right_down[0]), int(right_down[1])), (0, 0, 255), 2)
    # The first character is Chinese character, can not be printed normally, thus is omitted.
    lpn = alphabets[labelPred[1]] + ads[labelPred[2]] + ads[labelPred[3]] + ads[labelPred[4]] + ads[labelPred[5]] + ads[labelPred[6]]
    cv2.putText(img, lpn, (int(left_up[0]), int(left_up[1])-20), cv2.FONT_ITALIC, 2, (0, 0, 255))
    cv2.imwrite(ims[0], img)
```
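The bounding-box arithmetic in this loop treats the model's box as a centre (cx, cy) plus a size (w, h), normalised to [0, 1]. The code implies this by multiplying by the image width (shape[1]) and height (shape[0]) to get pixel coordinates. A standalone sketch follows; box_to_corners is a hypothetical helper name.

```python
def box_to_corners(cx, cy, w, h, img_width, img_height):
    """Convert a normalised (cx, cy, w, h) box to integer pixel corners.

    Mirrors the left_up / right_down computation in demo.py; int() truncates
    toward zero, exactly as the original code does.
    """
    left = (cx - w / 2) * img_width
    top = (cy - h / 2) * img_height
    right = (cx + w / 2) * img_width
    bottom = (cy + h / 2) * img_height
    return (int(left), int(top)), (int(right), int(bottom))
```

Take a 200×100 image (width × height) and a box centred at (0.5, 0.5) with w = 0.25 and h = 0.5. Then box_to_corners(0.5, 0.5, 0.25, 0.5, 200, 100) returns ((75, 25), (125, 75)), which can be passed to cv2.rectangle as its two corner points.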
Modified code:
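```python
# Extra imports at the top of demo.py (cv2 and torch are already imported there)
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont

os.makedirs('result', exist_ok=True)  # output directory for the annotated images
for i, (XI, ims) in enumerate(trainloader):
    # No Variable wrapper: since PyTorch 0.4, plain tensors support autograd
    if use_gpu:
        x = XI.cuda(0)
    else:
        x = XI
    # Forward pass: Compute predicted y by passing x to the model
    fps_pred, y_pred = model_conv(x)

    outputY = [el.data.cpu().numpy().tolist() for el in y_pred]
    labelPred = [t[0].index(max(t[0])) for t in outputY]

    [cx, cy, w, h] = fps_pred.data.cpu().numpy()[0].tolist()

    cv2Img = cv2.imread(ims[0])
    left_up = [(cx - w/2)*cv2Img.shape[1], (cy - h/2)*cv2Img.shape[0]]
    right_down = [(cx + w/2)*cv2Img.shape[1], (cy + h/2)*cv2Img.shape[0]]
    cv2.rectangle(cv2Img, (int(left_up[0]), int(left_up[1])), (int(right_down[0]), int(right_down[1])), (0, 0, 255), 2)
    # The first character (Chinese province abbreviation) is now looked up in provinces instead of being omitted
    lpn = provinces[labelPred[0]] + alphabets[labelPred[1]] + ads[labelPred[2]] + ads[labelPred[3]] + ads[labelPred[4]] + ads[labelPred[5]] + ads[labelPred[6]]
    print('Recognition result:', lpn)
    # Draw the Chinese text on a PIL image, since cv2.putText cannot render it
    pilImg = Image.fromarray(cv2.cvtColor(cv2Img, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(pilImg)
    # Arg 1: path to a font file with Chinese glyphs (e.g. SimHei); arg 2: font size
    font = ImageFont.truetype("simhei.ttf", 40, encoding="utf-8")
    # Arg 1: text position; arg 2: text; arg 3: text colour (RGB, i.e. red); arg 4: font
    draw.text((int(left_up[0]), int(left_up[1])-40), lpn, (255, 0, 0), font=font)
    # Convert the PIL image back to an OpenCV (BGR) image
    cv2charimg = cv2.cvtColor(np.array(pilImg), cv2.COLOR_RGB2BGR)
    # Old approach, kept for reference:
    # cv2.putText(cv2Img, lpn, (int(left_up[0]), int(left_up[1])-20), cv2.FONT_ITALIC, 2, (0, 0, 255))
    # Name the output after the input file; the original ims[0][-5:-4] kept only one
    # character of the file name, so different images could overwrite each other
    dstFileName = os.path.join('result', os.path.basename(ims[0]))
    cv2.imwrite(dstFileName, cv2charimg)
    print('Saved image to:', dstFileName)
```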
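In short: convert the OpenCV image (a BGR numpy array) to a PIL image, draw the Chinese text with PIL using a TrueType font that contains Chinese glyphs, then convert the result back to OpenCV's BGR format.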
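The round trip above can be wrapped in a small reusable helper. The sketch below rests on a few assumptions. The name draw_text_bgr and its parameters are hypothetical. It swaps the channels with numpy slicing instead of cv2.cvtColor; for 3-channel images the two are equivalent, and the slicing keeps the example free of an OpenCV dependency. When no font path is given, it falls back to PIL's default font, which can only render Latin text.

```python
import numpy as np
from PIL import Image, ImageDraw, ImageFont


def draw_text_bgr(img_bgr, text, xy, font_path=None, font_size=40, color=(255, 0, 0)):
    """Draw (possibly Chinese) text on an OpenCV-style BGR uint8 image via PIL.

    font_path should point to a TrueType font with Chinese glyphs (for example
    simhei.ttf). If it is None, PIL's default font is used: fine for ASCII,
    but it cannot render Chinese. font_size only applies to TrueType fonts.
    """
    # BGR -> RGB by reversing the channel axis (same effect as cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(np.ascontiguousarray(img_bgr[..., ::-1]))
    draw = ImageDraw.Draw(pil_img)
    if font_path is None:
        font = ImageFont.load_default()
    else:
        font = ImageFont.truetype(font_path, font_size)
    # color is RGB here, because we draw on the RGB PIL image
    draw.text(xy, text, fill=color, font=font)
    # RGB -> BGR again, so the result can go straight to cv2.imwrite
    return np.ascontiguousarray(np.array(pil_img)[..., ::-1])
```

In demo.py this would look like cv2charimg = draw_text_bgr(cv2Img, lpn, (x, y - 40), "simhei.ttf"), followed by cv2.imwrite. Here x and y stand for the integer left_up coordinates. The input array is copied, not modified.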
Result:
(Figure: recognition result)
I then tested on new data, using new images taken from Baidu, and the results were not ideal.
(Figure: new test images)
Reference blog: https://blog.csdn.net/qq_39622065/article/details/84859629