End-to-End License Plate Recognition with PyTorch


Author: 海盗船长_coco | Published: 2019-12-06 20:57

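    For a recent course project I needed to build a license plate recognition system, so I searched GitHub and found a paper from USTC (University of Science and Technology of China).
    Source code: https://github.com/detectRecog/CCPD (the repository also includes the paper)
    The training dataset is about 15 GB, so I did not download it and run the training myself. Instead I used the trained weights provided by the authors directly. The results were not ideal, and several places in the source code needed to be modified.
    Modified code: https://github.com/872699467/CCPD_CNN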

    1. Modify roi_pooling.py

    The original roi_pooling.py contains this import:

    from torch._thnn import type2backend
    

    It turned out that my torch installation has no _thnn module, so I deleted that line and imported this instead:

    import torch.nn.functional as F
    

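    At the same time, change the pooling call (shown as a screenshot captioned "the part to modify" in the original post) to: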

    output.append(F.adaptive_max_pool2d(im, size))
    

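    This works because torch already implements adaptive pooling; interested readers can look it up. The AdaptiveMaxPool2d class and the adaptive_max_pool function in the original source can therefore both be deleted.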
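
    To illustrate the replacement, here is a minimal, self-contained sketch of ROI pooling built on F.adaptive_max_pool2d. It is not the repository's roi_pooling.py: the function name roi_pool_sketch and the ROI layout (batch_index, x1, y1, x2, y2) in feature-map coordinates are assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def roi_pool_sketch(feature_map, rois, size=(7, 7)):
    """Crop each ROI from the feature map and max-pool it to a fixed size.

    feature_map: tensor of shape (N, C, H, W)
    rois: tensor of shape (K, 5); each row is the assumed layout
          (batch_index, x1, y1, x2, y2) with inclusive corner coordinates
    """
    outputs = []
    for idx, x1, y1, x2, y2 in rois.long().tolist():
        # Keep the batch dimension so every crop has shape (1, C, h, w)
        region = feature_map[idx:idx + 1, :, y1:y2 + 1, x1:x2 + 1]
        # The built-in adaptive pooling replaces the hand-written pooling code
        outputs.append(F.adaptive_max_pool2d(region, size))
    return torch.cat(outputs, dim=0)


# Two images, three channels, 16x16 feature maps filled with increasing values
features = torch.arange(2 * 3 * 16 * 16, dtype=torch.float32).reshape(2, 3, 16, 16)
rois = torch.tensor([[0, 0, 0, 7, 7], [1, 4, 4, 15, 11]])
pooled = roi_pool_sketch(features, rois, size=(2, 2))
print(pooled.shape)
```

    Whatever the size of each cropped region, the pooled output has the fixed spatial size given by size, which is why a hand-written pooling class is no longer needed.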

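    2. Modify demo.py

    With the changes above, demo.py should run correctly. However, the recognized plate in the output does not include the Chinese character.


    Figure: detection result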

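    The code contains the following comment, which explains that the first character is Chinese, cannot be printed normally, and is therefore omitted: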

    #   The first character is Chinese character, can not be printed normally, thus is omitted.
    

    So the code has to be modified to display the Chinese character.
    Original code:

    for i, (XI, ims) in enumerate(trainloader):
        if use_gpu:
            x = Variable(XI.cuda(0))
        else:
            x = Variable(XI)
        # Forward pass: Compute predicted y by passing x to the model
        fps_pred, y_pred = model_conv(x)
        outputY = [el.data.cpu().numpy().tolist() for el in y_pred]
        labelPred = [t[0].index(max(t[0])) for t in outputY]
        [cx, cy, w, h] = fps_pred.data.cpu().numpy()[0].tolist()
        img = cv2.imread(ims[0])
        left_up = [(cx - w/2)*img.shape[1], (cy - h/2)*img.shape[0]]
        right_down = [(cx + w/2)*img.shape[1], (cy + h/2)*img.shape[0]]
        cv2.rectangle(img, (int(left_up[0]), int(left_up[1])), (int(right_down[0]), int(right_down[1])), (0, 0, 255), 2)
        #   The first character is Chinese character, can not be printed normally, thus is omitted.
        lpn = alphabets[labelPred[1]] + ads[labelPred[2]] + ads[labelPred[3]] + ads[labelPred[4]] + ads[labelPred[5]] + ads[labelPred[6]]
        cv2.putText(img, lpn, (int(left_up[0]), int(left_up[1])-20), cv2.FONT_ITALIC, 2, (0, 0, 255))
        cv2.imwrite(ims[0], img)
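
    The two post-processing steps in the loop above (picking the most likely class for each of the seven plate characters, and turning the normalized box (cx, cy, w, h) into pixel corners) can be sketched in isolation as follows. This is only an illustration, not the repository's code: the function names decode_plate and box_to_corners are made up for this example, and the three lookup tables are shortened, hypothetical stand-ins for the much longer provinces, alphabets and ads lists in demo.py. Unlike the original loop, the sketch also looks up the first (province) character.

```python
# Hypothetical, shortened lookup tables (assumption: demo.py defines longer ones)
provinces = ["皖", "沪", "津"]
alphabets = ["A", "B", "C"]
ads = ["A", "B", "C", "0", "1", "2"]


def decode_plate(scores_per_char):
    """Map seven lists of class scores to a plate string.

    Character 0 uses the province table, character 1 the alphabet table,
    and characters 2 to 6 the combined letters-and-digits table.
    """
    indices = [scores.index(max(scores)) for scores in scores_per_char]
    tables = [provinces, alphabets] + [ads] * 5
    return "".join(table[i] for table, i in zip(tables, indices))


def box_to_corners(cx, cy, w, h, img_w, img_h):
    """Convert a normalized center/size box to integer pixel corners."""
    left_up = (int((cx - w / 2) * img_w), int((cy - h / 2) * img_h))
    right_down = (int((cx + w / 2) * img_w), int((cy + h / 2) * img_h))
    return left_up, right_down


scores = [
    [0.1, 0.8, 0.1],                  # province
    [0.7, 0.2, 0.1],                  # first letter
    [0.0, 0.0, 0.0, 0.9, 0.0, 0.1],
    [0.0, 0.6, 0.0, 0.0, 0.4, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.2, 0.8],
    [0.1, 0.0, 0.0, 0.0, 0.9, 0.0],
    [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
]
plate = decode_plate(scores)
corners = box_to_corners(0.5, 0.5, 0.5, 0.25, img_w=720, img_h=1160)
print(plate, corners)
```

    Building the string is never the problem, since Python strings handle Chinese characters without trouble; the difficulty only appears when the text has to be drawn onto the image.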
    

    Modified code:

    # Needed at the top of demo.py in addition to the existing imports:
    #   import os
    #   import numpy as np
    #   from PIL import Image, ImageDraw, ImageFont
    os.makedirs('result', exist_ok=True)  # cv2.imwrite does not create missing directories
    for i, (XI, ims) in enumerate(trainloader):
        if use_gpu:
            x = XI.cuda(0)
        else:
            x = XI
        # Forward pass: Compute predicted y by passing x to the model
        fps_pred, y_pred = model_conv(x)
        outputY = [el.data.cpu().numpy().tolist() for el in y_pred]
        labelPred = [t[0].index(max(t[0])) for t in outputY]
        [cx, cy, w, h] = fps_pred.data.cpu().numpy()[0].tolist()
        cv2Img = cv2.imread(ims[0])
        left_up = [(cx - w/2)*cv2Img.shape[1], (cy - h/2)*cv2Img.shape[0]]
        right_down = [(cx + w/2)*cv2Img.shape[1], (cy + h/2)*cv2Img.shape[0]]
        cv2.rectangle(cv2Img, (int(left_up[0]), int(left_up[1])), (int(right_down[0]), int(right_down[1])), (0, 0, 255), 2)
        # The first character is the Chinese province abbreviation; it is now included in the plate string
        lpn = provinces[labelPred[0]] + alphabets[labelPred[1]] + ads[labelPred[2]] + ads[labelPred[3]] + ads[labelPred[4]] + ads[labelPred[5]] + ads[labelPred[6]]
        print('Recognition result:', lpn)
        # Convert the cv2 image (BGR) to a PIL image (RGB) so that Chinese text can be drawn
        pilImg = Image.fromarray(cv2.cvtColor(cv2Img, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pilImg)
        # Arguments: font file path (the font must contain Chinese glyphs), font size
        font = ImageFont.truetype("simhei.ttf", 40)
        # Arguments: position, text, color (RGB), font
        draw.text((int(left_up[0]), int(left_up[1]) - 40), lpn, (255, 0, 0), font=font)
        # Convert the PIL image back to a cv2 image
        cv2charimg = cv2.cvtColor(np.array(pilImg), cv2.COLOR_RGB2BGR)
        # Use the full base name of the input file so that results do not overwrite each other
        dstFileName = os.path.join('result', os.path.splitext(os.path.basename(ims[0]))[0] + '.jpg')
        cv2.imwrite(dstFileName, cv2charimg)
        print('Image saved to', dstFileName)
    

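    cv2.putText only supports OpenCV's built-in Hershey fonts, which cannot render Chinese characters. The image is therefore converted to PIL format, the Chinese text is drawn with PIL, and the image is then converted back to OpenCV format.
    The result:

    Figure: result after the modification
    I also tested the model on new images taken from Baidu, and the results were not ideal.
    Figure: new test image
    Reference blog: https://blog.csdn.net/qq_39622065/article/details/84859629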
