美文网首页
源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理

源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理

作者: CW不要无聊的风格 | 来源:发表于2020-06-29 14:05 被阅读0次

    Date: 2020/06/28

    Author: CW

    前言:

    本文会对模型训练部分的代码进行解析,主要把训练过程的pipeline过一下,其中有些部分的具体实现会放到后面的篇章中讲解,这部分源码对应于项目的 main.py 文件。


    Outline

    1、训练pipeline

    2、一个训练周期的过程

    3、数据处理


    训练pipeline

    首先是解析运行脚本时用户输入的参数,然后创建记录结果的目录,最后根据用户指定的参数进行训练。

    pipeline整体过程

    get_args_parser() 方法中设置了用户可以指定的参数项,想要了解的朋友们可以去参考源码,这里就不再阐述了,接下来主要看看 main() 方法,其中的内容就是训练过程的pipeline。

    init_distributed_mode() 方法是与分布式训练相关的设置,在该方法里,是通过环境变量来判断是否使用分布式训练,如果是,那么就设置相关参数,具体可参考 util/misc.py 文件中的源码,这里不作解析。

    训练pipeline(i)

    参数项 frozen_weights 代表是否固定住参数的权重,类似于迁移学习的微调。如果是,那么需要同时指定 masks 参数,代表这种条件仅适用于分割任务。上图最后部分是固定随机种子,以便复现结果。

    然后就是构造模型、loss函数以及后处理方法、输出可训练的参数数量。

    训练pipeline(ii)

    下图的部分包括设置优化器、学习率策略以及构建训练和验证集。

    训练pipeline(iii)

    从上图可以看到,这里将backbone和其它部分的参数分开,以便使用不同的初始学习率进行训练。构造数据集使用的 build_dataset() 方法调用了COCO数据集的api,其中的内容具体会在后文展示。

    构造了数据集后,设置数据集的采样器,并且装在到 DataLoader,以进行批次训练。

    训练pipeline(iv)

    注意到以上使用了 collate_fn 方法来重新组装一个batch的数据,具体细节会在后面数据处理部分一并讲解。

    训练pipeline(v)

    下图部分主要是用于从历史的某个训练阶段中恢复过来,包括加载当时的模型权重、优化器和学习率等参数。

    训练pipeline(vi) 训练pipeline(vii)

    接下来真正开始一个个周期地训练,每个周期后根据学习率策略调整下学习率。

    训练pipeline(viii)

    下图部分是将训练结果和相关参数记录到指定文件。

    训练pipeline(ix) 训练pipeline(x)

    下图中的内容是将训练和验证的结果记录到(分布式)主节点中指定的文件。

    训练pipeline(xi)

    最后计算训练的总共耗时并且打印,整个训练流程就此结束。

    训练pipeline(xii)

    一个训练周期的过程

    这部分对应的代码在 detr/engine.py 中的 train_one_epoch() 方法,在上一节的图中也能看到。顾名思义,这部分内容就是模型在一个训练周期中的操作,下面就来一起瞄瞄里面有啥值得学习的地方。

    惯用套路,首先将模型设置为训练模式,这样梯度才能进行反向传播,从而更新模型参数的权重。注意到这里同时将 criterion 对象也设为train模式,它是 SetCriterion 类的一个对象实例,代表loss函数,看了下相关代码发现里面并没有需要学习的参数,因此感觉之类可以将这行代码去掉,后面我会亲自实践看看,朋友们也可一试。

    train_one_epoch(i)

    这里用到了一个类 MetricLogger(位于 detr/util/misc.py),它主要用于log输出,其中使用了一个defaultdict来记录各种数据的历史值,这些数据为 SmoothValue(位于 detr/util/misc.py) 类型,该类型通过指定的窗口大小(上图中的 window_size)来存储数据的历史步长(比如1就代表不存储历史记录,每次新的值都会覆盖旧的),并且可以格式化输出。另外 SmoothValue 还实现了统计中位数、均值等方法,并且能够在各进程间同步数据。

    MetricLogger 除了通过key来存储SmoothValue以外,最重要的就是其实现了一个log_every的方法,这个方法是一个生成器,用于将每个batch的数据取出(yeild),然后该方法内部会暂停在此处,待模型训练完一次迭代后再执行剩下的内容,进行各项统计,然后再yeild下一个batch的数据,暂停在那里,以此重复,直至所有batch都训练完。这种方式在其它项目中比较少见,感兴趣的炼丹者们可以一试,找些新鲜感~

    train_one_epoch(ii)

    在计算出loss后,若采用了分布式训练,那么就在各个进程间进行同步。另外,若梯度溢出了,那么此时会产生梯度爆炸,于是就直接结束训练。

    train_one_epoch(iii)

    于是,为避免梯度爆炸,在训练过程中,对梯度进行裁剪,裁剪方式有很多种,可以直接对梯度值处理,这里的方式是对梯度的范式做截断,默认是第二范式,即所有参数的梯度平方和开方后与一个指定的最大值(下图中max_norm)相比,若比起大,则按比例对所有参数的梯度进行缩放。

    train_one_epoch(iv)

    最后,将 MetricLogger 统计的各项数据在进程间进行同步,同时返回它们的历史均值,对于这个历史均值的解释见下图注释。

    train_one_epoch(v)

    关于 MetricLogger 和 SmoothValue 的具体实现这里就不作解析了,这只是作者的个人喜好,用于训练过程中数据的记录与展示,和模型的工作原理及具体实现无关,大家如果想要将 DETR 用到自己的项目上,完全可以不care这部分。对于 MetricLogger 和 SmoothValue 的这种做法,我们可以学习下里面的技巧,抽象地继承,而不必生搬硬套。


    数据处理

    先来讲解下第一部分中拉下的collate_fn(),它的作用是将一个batch的数据重新组装为自定义的形式,输入参数batch就是原始的一个batch数据,通常在Pytorch中的Dataloader中,会将一个batch的数据组装为((data1, label1), (data2, label2), ...)这样的形式,于是第一行代码的作用就是将其变为[(data1, data2, data3, ...), (label1, label2, label3, ...)]这样的形式,然后取出batch[0]即一个batch的图像输入到nested_tensor_from_tensor_list()方法中进行处理,最后将返回结果替代原始的这一个batch图像数据。

    collate_fn

    接着来看看nested_tensor_from_tensor_list()是如何操作的。首先,为了能够统一batch中所有图像的尺寸,以便形成一个batch,我们需要得到其中的最大尺度(在所有维度上),然后对尺度较小的图像进行填充(padding),同时设置mask以指示哪些部分是padding得来的,以便后续模型能够在有效区域内去学习目标,相当于加入了一部分先验知识。

    nested_tensor_from_tensor_list

    下图演示了如何得到batch中每张图像在每个维度上的最大值,代码已经show得很明白了,CW无需多言。

    _max_by_axis

    构建数据集使用的是 build_dataset() 这个方法,该方法位于 datasets/__init__.py 文件。方法内部根据用户参数来构造用于目标检测/全景分割的数据集。image_set 是一个字符类型的参数,代表要构造的是训练集还是验证集。

    build_dataset

    针对目标检测任务,我们来看看 build_coco() 这个方法的内容,该方法位于 datasets/coco.py

    build

    这个方法首先检查数据文件路径的有效性,然后构造一个字典类型的 PATHS 变量来映射训练集与验证集的路径,最后实例化一个 CocoDetection() 对象,CocoDetection 这个类继承了torchvision.datasets.CocoDetection。

    CocoDetection

    在类的初始化方法中,首先调用父类的初始化方法,将图像文件及标注文件的路径传进去。transforms 是用于数据增强的方法;根据名字来看,ConvertCocoPolysToMask() 这个对象是将数据标注的多边形坐标转换为掩码,但其实不仅仅是这样,或者说不一定是这样,因为需要根据传进去的参数 return_masks 来确定。

    另外,需要提下COCO数据集中标注字段annotation的格式,对于目标检测任务,其格式如下:

    annotation

    当 "iscrowd" 字段为0时,segmentation就是polygon的形式,比如这时的 "segmentation" 的值可能为 [[510.66, 423.01, 511.72, 420.03, 510.45......], ..],其中是一个个polygon即多边形,,这些数按序两两组成多边形各个点的横、纵坐标,也就是说,表示polygon的list中如果有n个数(必定是偶数),那么就代表了 n/2 个点坐标。

    至于取数据用到的 __getitem__ 方法,首先也是调用父类的这个方法获得图像和对应的标签,然后 prepare 就是调用 ConvertCocoPolysToMask() 这个对象对图像和标签进行处理,之后若有指定数据增强,则进一步进行对应的处理,最后返回这一系列处理后的图像和对应的标签。

    现在我们来看看 ConvertCocoPolysToMask 这个类内部究竟玩了些什么东东。

    ConvertCocoPolysToMask(i)

    这里的 target 是一个list,其中包含了多个字典类型的annotation,每个annotation的格式如上一部分的图中所示。这里将 "iscrowd" 为1的数据(即一组对象,如一群人)过滤掉了,仅保留标注为单个对象的数据。

    另外这里对bbox的形式做了转换,将"xywh"转换为"x1y1x2y2"的形式,并且将它们控制图像尺寸范围内。

    ConvertCocoPolysToMask(ii)

    通过上图可以了解到,若传进来的 return_masks 值不为True,那么实质上是没有做 "convert_poly_to_mask" 这个操作的,这也是为何我在上述提到 ConvertCocoPolysToMask() 这个对象的实际操作可能和其命名有所差异。

    下图中,keep 代表那些有效的bbox,即左上角坐标小于右下角坐标那些,过滤掉无效的那批。

    ConvertCocoPolysToMask(iii)

    在进行完处理和过滤操作后,更新annotation里各个字段的值,同时新增 "orig_size" 和 "size" 两个 key,最后返回处理后的图像和标签。

    ConvertCocoPolysToMask(iv)

    综上所述,ConvertCocoPolysToMask() 仅在传入的参数 return_masks 为True时做了将多边形转换为掩码的操作,该对象的主要工作其实是过滤掉标注为一组对象的数据,以及筛选掉bbox坐标不合法的那批数据。

    现在我们来看看 convert_coco_poly_to_mask() 这个方法即将多边形坐标转换为掩码是如何操作的。

    convert_coco_poly_to_mask

    该方法中调用的 frPyObjects 和 decode 都是 coco api(pycocotools)中的方法,将每个多边形结合图像尺寸解码为掩码,然后将掩码增加至3维(若之前不足3维)。

    这里有个实现上的细节——为何要加一维呢?因为我们希望的是这个mask能够在图像尺寸范围(h, w)中指示每个点为0或1,在解码后,mask的shape应该是 (h,w),加一维变为 (h,w,1),然后在最后一个维度使用any()后才能维持原来的维度即(h,w);如果直接在(h,w)的最后一维使用any(),那么得到的shape会是(h,),各位可以码码试试。

    最后,将一个个多边形转换得到的掩码添加至列表,堆叠起来形成张量后返回。

    在本系列第一篇文中我就提到过,说 DETR 的整体工作很solid,没有使用骚里骚气的数据增强,那么我们就来看看它究竟在数据增强方面做了啥。

    make_coco_transforms(i)

    可以看到,真的是很“老土”!就是归一化、随机反转、缩放、裁剪,除此之外,没有了,可谓大道至简~

    make_coco_transforms(ii)

    另外,提及下,上图中的 T 是项目中的 datatsets/transforms.py 模块,以上各个数据增强的方法在该模块中的实现和 torchvision.transforms 中的差不多,其中ToTensor()会将图像的通道维度排列在第一个维度,并且像素值归一化到0-1范围内;而Normalize()则会根据指定的均值和标准差对图像进行归一化,同时将标签的bbox转换为c_{x} c_{y} wh形式后归一化到0-1,此处不再进行解析,感兴趣的可以去参考源码。


    @最后

    通常,很多项目在数据处理部分都会相对复杂,一方面固然是因为数据处理好了模型才能进行有效训练与学习,而另一方面则是为了适应任务需求而“不得已”处理成这样,其中还可能会使用到一些算法技巧,但是在 DETR中,真的太简单了,coco api 几乎搞定了一切,然后搞几个超级老土的 data augmentation,完事,666!

    相关文章

      网友评论

          本文标题:源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理

          本文链接:https://www.haomeiwen.com/subject/iccwfktx.html