美文网首页
队服撞衫?如何让AI区分相似球衣?

队服撞衫?如何让AI区分相似球衣?

作者: adi0229 | 来源:发表于2019-01-03 18:01 被阅读18次

    问题

    笔者最近在学人工智能领域的深度学习技术,在 fast.ai的课程里,我学会了:复用简洁的示例代码来训练卷积神经网络CNN。它能够精准地让机器对图片进行分类。(课程里的例子是识别猫狗图片的二元分类任务,在2000张测试图片中,它的识别准确率达到了99%。)

    我认为,真正的学习,是带着好奇心,自己来解决实际问题,探究输入-输出之间的关系,从而构建自己的知识体系。

    因为自己是多年的足球迷,从兴趣出发,我“凿开”了一个脑洞:能否训练一个卷积神经网络训练,识别区分两件相似足球队衣图片?

    于是,我搜索到了公众号点妹侃球的一篇帖子 《撞衫不可怕,谁丑谁尴尬!这些球衣相似度竟高达99%》,看看都有哪些足球队的队服比较相似。

    其中,我发现,大名鼎鼎的巴塞罗那队服是红蓝箭条衫,而瑞士巴赛尔队的队服也是红蓝箭条衫,两队的主场战袍相似度很高。因此,我打算用 fast.ai的模块来以及预训练的 CNN 卷积神经网络(比如,resnet34)来进行迁移学习,看看我们的 AI 图片分类器,能否较好地识别两个队的队服。

    下面是两队的队服示例:

    巴塞罗那队服 巴塞尔队服

    任务定义:区分图片是「巴塞罗那队队服」还是「巴赛尔队队服」?

    数据集

    • 通过开源的 Python 脚本 google-images-download,根据关键词来批量谷歌图片里的图像。每队批量下载了140张图片,训练集每队约83张图片,验证集每队约60张图片。
    • 巴塞罗那红蓝箭条衫队服
    巴萨球衣图片下载过程 巴萨球衣图片概述
    • 巴赛尔红蓝箭条衫队服
    巴赛尔球衣图片下载过程 巴赛尔球衣图片概述

    训练

    预训练模型

    定义数据路径及图片大小(324*324)

    PATH = "./data/basel_or_barcelona/"
    sz=324
    

    设置预训练模型为resnet34,然后生成模型,学习率设为0.01,并训练10次

    arch=resnet34
    data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, sz)) 
    learn = ConvLearner.pretrained(arch, data, precompute=True) 
    learn.fit(0.01, 10)
    
    100% 10/10 [00:01<00:00, 5.81it/s]
    epoch      trn_loss   val_loss   accuracy                
        0      0.773914   0.711084   0.586777  
        1      0.694298   0.670647   0.644628                
        2      0.597124   0.62849    0.710744                
        3      0.505373   0.600698   0.735537                
        4      0.450973   0.583853   0.760331        
        5      0.42152    0.595698   0.768595        
        6      0.373794   0.610405   0.785124        
        7      0.33496    0.624578   0.77686                 
        8      0.304601   0.634755   0.77686         
        9      0.277069   0.641988   0.77686     
    

    采用预训练的模型,120多张的测试数据,准确率达到了77.6%,还不错。这是最简单的办法,fast.ai课程还讲授了其他的提示准确率的办法,对于这个小型的足球图片数据集,我来实验一下,看看准确率能否有明显提升。

    寻找最佳学习率

    使用了一个寻找最佳学习率的函数,但是数据图里是空的,没有曲线,暂时无解,先跳过,我决定沿用0.01的学习率。

    数据扩充(Data Argumentation)

    tfms = tfms_from_model(resnet34, sz, aug_tfms=transforms_side_on, max_zoom=1.5)
    

    随机对图片进行水平旋转,并放大1.5倍

    image.png
    data = ImageClassifierData.from_paths(PATH, tfms=tfms)
    learn = ConvLearner.pretrained(arch, data, precompute=True)
    learn.fit(1e-2, 1)
    

    用了 数据扩充(Data Argumentation) 之后,第一次训练,准确率较低,只有52.8%。

    learn.precompute=False
    learn.fit(1e-2, 3, cycle_len=1)
    epoch      trn_loss   val_loss   accuracy                
        0      0.711067   0.799083   0.561983  
        1      0.673241   0.670506   0.652893                
        2      0.632643   0.606632   0.644628
    

    解冻

    之前训练的是最后一层,通过 Unfreeze函数,“解冻”所有神经层,进一步做 Fine-Tuning 参数微调,并且,不同深浅的神经层,采用不同的学习率。

    learn.unfreeze()
    lr=np.array([1e-4,1e-3,1e-2])
    learn.fit(lr, 3, cycle_len=1, cycle_mult=2)
    epoch      trn_loss   val_loss   accuracy                
        0      0.577406   0.585321   0.652893  
        1      0.574191   0.495774   0.727273                
        2      0.505914   0.456579   0.752066                 
        3      0.456115   0.393354   0.801653                
        4      0.408209   0.363914   0.818182                
        5      0.370777   0.353467   0.834711                
        6      0.342305   0.350941   0.842975 
    
    learn.fit(lr, 6, cycle_len=1, cycle_mult=2)
    epoch      trn_loss   val_loss   accuracy                
        0      0.224649   0.351165   0.842975  
        1      0.201187   0.343334   0.867769                
        2      0.192907   0.339028   0.867769                
        3      0.187865   0.32258    0.867769                
        4      0.167363   0.309566   0.867769                
        5      0.165586   0.302056   0.859504                
        6      0.162697   0.303541   0.859504                
        7      0.160678   0.301258   0.867769                
        8      0.150494   0.315228   0.876033                
        9      0.149012   0.333124   0.876033                
        10     0.140085   0.341198   0.884298                
        11     0.133135   0.343363   0.884298                
        12     0.125423   0.339807   0.884298                
        13     0.1168     0.33534    0.884298                
        14     0.109273   0.33899    0.884298                
        15     0.10801    0.32313    0.884298                
        16     0.101856   0.313022   0.892562                
        17     0.099126   0.29976    0.884298                 
        18     0.094342   0.293067   0.884298                 
        19     0.090559   0.291236   0.884298                 
        20     0.089409   0.294657   0.884298                 
        21     0.085061   0.292772   0.884298                 
        22     0.080943   0.294916   0.884298                 
        23     0.077687   0.289591   0.884298                 
        24     0.073291   0.290864   0.884298                  
        25     0.070136   0.289896   0.884298                  
        26     0.071041   0.291556   0.876033                 
        27     0.067812   0.287182   0.884298                  
        28     0.06408    0.287058   0.884298                 
        29     0.062913   0.288546   0.884298                  
        30     0.060431   0.286025   0.884298                 
        31     0.060556   0.284598   0.884298                 
        32     0.058281   0.290405   0.884298                 
        33     0.056919   0.295286   0.892562                  
        34     0.054588   0.29396    0.900826                 
        35     0.052521   0.292013   0.892562                 
        36     0.051742   0.281408   0.892562                  
        37     0.050915   0.274899   0.892562                  
        38     0.048206   0.265904   0.892562                 
        39     0.048089   0.2706     0.892562                 
        40     0.045976   0.286523   0.892562                 
        41     0.046062   0.297521   0.892562                  
        42     0.043593   0.297748   0.900826                 
        43     0.041344   0.293669   0.900826                 
        44     0.040934   0.297795   0.900826                 
        45     0.039792   0.300095   0.900826                 
        46     0.038014   0.300615   0.900826                 
        47     0.037288   0.298393   0.892562                 
        48     0.037229   0.294151   0.900826                 
        49     0.035728   0.290464   0.884298                 
        50     0.034198   0.286805   0.884298                 
        51     0.033855   0.286428   0.876033                 
        52     0.032238   0.282124   0.884298                 
        53     0.031438   0.276031   0.892562                 
        54     0.031037   0.282506   0.876033                 
        55     0.02959    0.285489   0.884298                 
        56     0.028276   0.282334   0.892562                 
        57     0.027054   0.278635   0.892562                 
        58     0.025802   0.281353   0.892562                 
        59     0.024861   0.282644   0.892562                 
        60     0.023592   0.281519   0.892562                 
        61     0.023124   0.283864   0.884298                 
        62     0.022198   0.284992   0.892562
    
    log_preds,y = learn.TTA()
    probs = np.mean(np.exp(log_preds),0)
    accuracy_np(probs, y)
    0.8925619834710744
    

    采用了 TTA(测试集也使用 Data Argumentation) 之后,最后的准确率达到了89.2%。

    结果

    混淆矩阵(Confusion Matrix)

    画出一个混淆矩阵——查看不同分类,识别错误的图片都有哪些?

    Confusion Matrix

    看图可知

    巴塞罗那队服:54张正确,6张错误。
    巴塞尔队服:54张正确,7张错误。

    下面我们看看具体哪些图片识别错了:

    第一行几张图片,判断错误还是有点奇怪的,因为图片本身有着巴塞罗那队的标志,标志人物梅西和队徽。也许是训练集还太小的缘故吧。

    第二行判断错误的巴赛尔队服,第二张是异常图片,呈黑白色。其他三张与训练集里典型的巴赛尔队服还是有点区别的。这样也许解释得过去。

    参考资料:

    Fun with small image data-sets
    Case Study: A world class image classifier for dogs and cats
    fastai-doc

    相关文章

      网友评论

          本文标题:队服撞衫?如何让AI区分相似球衣?

          本文链接:https://www.haomeiwen.com/subject/qseplqtx.html