比赛 | 京东AI时尚单品检索:一次小小的尝试

作者: 与阳光共进早餐 | 来源:发表于2018-07-20 12:42 被阅读15次

    一 写在前面

    未经允许,不得转载,谢谢。

    前前后后大概集中做了20来天,第一次尝试竞赛类型的项目,终于磕磕碰碰的走完了整个过程,虽然没有取得一个好结果,但也学习了很多的东西。

    想来想去还是将整个项目的开展情况以及重点做一个梳理吧,就当画个句号。

    期待后期能学习到获奖团队的方案方法,用以反思借鉴。

    本着不能浪费小伙伴宝贵的阅读时间的原则,再说一下本篇文章仅仅是个人对于比赛的一次尝试,是否值得借鉴就请小伙伴自己评估啦~~~

    如果有同样做过这个比赛的大佬路过,还希望不吝赐教哇。(〃'▽'〃)

    二 题目介绍及理解

    1主要任务

    1. 主要想法:以图搜图;
    2. 找到与实拍图最相近的电商展示图,返回匹配度最高的10个ID;
    3. 每个时尚单品都属于{上衣,鞋子,箱包}中的一类;
    4. 每个时尚单品均需要根据单品坐标进行抠图处理,每张图片的给定区域只有一个物品;

    2 数据集

    • 训练数据:1.2万个真实的京东时尚单品图片对P
    • 训练数据:{电商展示图+用户实拍图}
    • 测试数据:2000张用户实拍图
    • 15万张全部电商展示图S
    • 都经过抠图,指定好时尚单品的位置坐标

    3 基本思路

    1. 根据提供的URL获取所有的图片;
    2. 根据给出的坐标,抠出图像中的物体小图片;
    3. 对所有的物品通过网络模型提取特征;
    4. 找出跟自己特征最相似的图片作为检索结果;
    5. 根据测试数据对用深度学习调整模型结构及参数;

    处理这类问题的几个大步骤都是类似的,基本可以分成:

    1. 数据集处理;
    2. 模型设计与实现;
    3. 训练、测试、优化、再重复2.3

    篇幅有限,太细节的东西就不再写了。挑一些觉得对自己以后做项目或者对大家会有帮助的东西吧。

    三 数据集准备及预处理

    1. 根据url下载图片

    数据集很大的情况,常常需要我们自己去下载图片,这个时候就需要有个程序帮我们自动下载了。

    • 用urllib获取图片并保存
    import urllib
    # img_url: the url of image
    # img_path: the path you want to save image
    urllib.urlretrieve(img_url,img_path)
    

    2. 图片加载与处理

    1. 用PIL加载图像

    from PIL import Image
    
    def get_image_from_path(img_path,img_region):
        image = Image.open(img_path)
        image = process_image_channels(image, img_path)
        image = image.crop(img_region)
        return image
    

    2. 关于crop函数

    • 一定要注意bounding_box的传入参数;
    • crop接受的参数为(左上x,左上y,右下x,右下y)
    • python的坐标系为最左上角为(0,0),横向x,纵向y;
    • 这里踩了好久的坑。╮(╯﹏╰)╭

    3. 关于处理图像通道

    • 在这次处理的数据集中有jpg的图像,也有png的图像;
    • 以前从来不知道png会有RGBA4个通道甚至有些图片只有一个A通道,所以如果没有提前处理后面训练或者换测试的时候会时不时的给你一个bug小彩蛋哈哈哈。
    • 关键语句:
    def process_image_channels(image, image_path):
        # process the 4 channels .png
        if image.mode == 'RGBA':
            r, g, b, a = image.split()
            image = Image.merge("RGB", (r, g, b))
        # process the 1 channel image
        elif image.mode != 'RGB':
            image = image.convert("RGB")
            os.remove(image_path)
            image.save(image_path)
        return image
    

    3. 训练数据集划分

    • 在这类具体的情况中我们往往之后只有训练数据;
    • 通常将训练数据分成:90%训练集 + 10%测试集;
    • 这一点对于新手还是比较容易被忽略的。

    四 理论学习与模型确定

    1. 主要参考资料

    2. 论文重点整理博客

    把两篇论文的重点整理在这里了,如果不想看原文的话也可以看看这2篇博客:

    3 确定baseline模型及基本方案

    1. baseline模型

    • 选择了facenet中的网络结果作为基本的baseline;


    2. 基本方案

    • 深度网络模型: 采用在ImageNet上预训练过的DenseNet161
    • 用Densenet161对输入图像提取特征;
    • 训练过程用triplet loss进行训练;
    • 验证及测试直接用向量之间的余弦相似度来表示,选择最相似的top-10作为输出结果。

    五 代码设计与编写

    1. 代码构成

    • 先上一张整个项目的构成图吧。


    • 其中以a_开头的是重构后的代码,也就是最终使用的。

    2. 各个文件说明

    我会对各个文件的构成做一个简单的介绍,然后将重点的部分单独整理成blog,并给出链接。

    1. a_utils_fyq.py
      用于放一些经常会被使用的,便于减少其他代码文件的工作量,增强代码可读性。包括:

      • 全局路径
      • 一些通用函数(例如根据图片地址获取图片)
      • 或者完全可以独立于其他类的函数(例如多行向量之间余弦距离的计算)。
    2. a_network_fyq.py

      • 定义了图像检索任务中用来提取图像特征的网络结构。
      • 直接用了预训练好的densenet161,然后提取它的features作为输出。
      • 用pytorch可以很简单的实现网络模型的定义;
      • 关键代码:
    class Net(nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            self.denseNet161=torchvision.models.densenet161(pretrained=True)
    
        def forward(self,x):
            x = self.denseNet161.features(x)
            out = x.view(x.size(0), -1)
            return out
    
    1. a_dataloader_val_fyq.py
      • 用于在验证时加载单张图片;
      • 同时适用于query和rep的图片加载;
    2. a_dataloader_train_fyq.py
    3. a_train_nework_fyq.py
      • 用于训练图像检索神经网络
      • 网络训练主要包括:获取到批量数据,然后前向传播、计算loss值,更新参数即可。
      • 可以参见:Pytorch入门学习(四)-training a classifier
      • triplet_loss的实现也可以通过pytorch自带的函数很方便的实现。
      • 关键代码:
        def triplets_loss(self,anchor,positive,negative):
            triplet_loss=nn.TripletMarginLoss(margin=self.margin,p=2)
            output_loss=triplet_loss(anchor,positive,negative)
            return output_loss
    
    1. a_image_retrieval.fy
      • 用于通过模型得到query以及rep的特征值、计算多行向量之间的余弦距离,然后排序得到top-10结果。
      • 网上有很多向量之间计算余弦距离的代码可以参考,但是我还没有找到基于矩阵的。所以就自己琢磨着写了一下,也分享给大家。
      • 多行向量求余弦距离:
        def cosine_distance(self, matrix1, matrix2):
            matrix1_matrix2 = np.dot(matrix1, matrix2.transpose())
            matrix1_norm = np.sqrt(np.multiply(matrix1, matrix1).sum(axis=1))
            matrix1_norm = matrix1_norm[:, np.newaxis]
            matrix2_norm = np.sqrt(np.multiply(matrix2, matrix2).sum(axis=1))
            matrix2_norm = matrix2_norm[:, np.newaxis]
            cosine_distance = np.divide(matrix1_matrix2, np.dot(matrix1_norm, matrix2_norm.transpose()))
            return cosine_distance
    
    1. a_classifer_fyq.py
      • 训练了一个分类器
      • 这个文件像是一个小型的项目,囊括了分类任务的模型定义、数据加载、网络训练、网络测试。不过这个也都是大同小异的过程,写过一个就可以触类旁通了。

    六 写在最后

    整体的实现过程就是这样,感觉应该也已经把最具有学习价值的东西提炼出来了。

    模型最后的精确度并不高,所以细节的训练以及测试的结果并不具有什么可参考的价值,就不再整理了。

    1. 关于triplet loss训练中样本的选择问题

    • 随机选择负样本真的效果很不好,亲测!
    • 在minibatch中在线选择semi-hard或者hard负样本的效果也不好,因为常常会受制于batch_size的大小,也是亲测!
    • 建议采用线下选择负样本的方式。
    • 总的来说,用triplet loss来训练检索模型对于样本的选择很重要,不然可能会出现像我一样越train越差的尴尬局面...

    2. 关于项目运行的一点经验

    • 有的时候需要对同一个模型做条件不同的测试,就可以用多脚本的方式,不容易搞混。
    • 在多块GPU的机器上运行代码,可以用CUDA_VISIBLE_DEVICES=" "的方式指定使用某一块GPU;
    • 数据量很大如果导致显卡溢出了,可以采用将数据移到CPU的解决方式,myVariable.cpu().numpy();
    • CPU如果还是内存溢出,那就只能选择手动分段来结果问题了。

    3. 关于项目代码的一点小感触

    这是我自己第一次完完全全一行一行地编写完一个完整的项目,算是这次学到的很宝贵的财富了吧。

    自己重构过几遍代码真的很重要,会更好的知道怎么样把一个大的事情一点一点的拆小,也会慢慢形成一个自己写代码的风格与习惯。

    看自己写的东西最酥糊了嘻嘻⁄(⁄⁄•⁄ω⁄•⁄⁄)⁄

    4. 关于科研的一点感悟

    想点子、编程实现、实验,然后重复循环。

    这或许就是科研本来的样子,希望自己能不断接纳挫折,接纳失败,然后接纳成果。

    能想到可以分享的大概就这些了。

    与你们共勉。

    相关文章

      网友评论

      • eec58ae62b57:楼主您好,冒昧的打扰了,请问你还有训练集吗,因为比赛已经截止,下载数据需要百度云的提取码,我不太清楚
        与阳光共进早餐:@xiaoxuJY 数据集太大,我已经没有保存了~~
        xiaoxuJY:楼主,数据集能分享一下吗?
        与阳光共进早餐:@漂漂一族 不好意思哈,刚看到评论。

      本文标题:比赛 | 京东AI时尚单品检索:一次小小的尝试

      本文链接:https://www.haomeiwen.com/subject/vdkxmftx.html