Self Adversarial Training for Human Pose Estimation
Official Code: pytorch
1.背景分析
由于人体的遮挡和拥挤等现象,现有的人体姿态估计网络很难解决此类情况下的准确估计,且此类现象会导致网络估计的关键点不符合正常的人体姿态,失去了人体固有的形态。比如下图中第二行图片所示,相较于第一行,很显然有部分关节是违背事实的。作者希望即使在拥挤状态下,网络预测得到关键点也能够符合关节所固有结构。基于此作者提出使用生成对抗的方式来解决这个问题。
2.自对抗网络结构
与传统的GAN模型类似,本文的模型分为两个网络,生成器和鉴别器。第一个网络生成器是一个卷积网络,生成器经过前向计算,得到一组热图,它指示每个关键点的每个位置的置信度得分。第二个网络鉴别器,具有与生成器相同的架构,但它将热图与RGB图像一起编码输入,并将其解码为一组新的热图,以便区分真实的热图和虚假的热图。本文提出的自对抗网络结果如下图所示。在最终做关键点前向推理时,会将鉴别器从整体的结果中剔除。
3.生成器
生成器主要的作用是生成准确的人体关键点信息。当然作为生成对抗中的一环,生成器最主要的功能就是能够让生成的关键点欺骗最终的鉴别器,使得鉴别器无法区分当前关键点热图是GT还是生成器生成的。因此,如下图所示,训练生成器时,其通过两部分进行优化,分别为反向传播来自生成器的损耗Lmse和来自鉴别器的对抗损耗Ladv。
整体的loss如下所示,公式1的损失Lmse目的是使得生成器最终生成的人体关键点能够更加接近标签。公式2的对抗性损失Ladv,该对抗损失的目的是使得生成器最终生成的关键点符合更加合理的姿态。更直接的说,Ladv的目的是使得生成器生成的虚假热度图能够尽可能的糊弄鉴别器,使其无法区分GT热图和虚假热图。生成对抗的过程就体现在这里。最终利用公式3所示的损失来优化生成器。其中lamda是一个超参数。
4.鉴别器
鉴别器的目标是区分输入进来的热图是GT还是生成器生成的虚假热图。鉴别器最终的训练目标就是能够把生成器生成的数据竟可能和GT区分出来。从而和生成器形成一个对抗博弈的过程。因此,如下图所示,训练鉴别器时,其通过两部分进行优化,分别为反向传播来自鉴别器的损耗Lreal和来自鉴别器的损耗Lfake。
整体的loss如下所示,公式(4.1)表示将GT热图输入鉴别器得到编码后的新热图,并计算新热图和GT热图的距离,进行Lreal损失计算。公式(4.2)表示将生成器生成的虚假热图输入鉴别器得到编码后的新热图,并计算新热图和生成器生成的虚假热图之间的距离进行Lfake损失计算。正如前述提到过的,鉴别器的目的是尽可能的将虚假热图和GT热图区分开来,也就是说鉴别器希望GT热图输入后的输出重构热图尽可能和GT接近,希望虚假热图输入后的输出重构热图尽可能和虚假热图不同。从loss上来说就是希望Lreal越来越小,希望Lfake越来越大。基于此,鉴别器的loss如公式(4.3)所示。
上述公式中的kt是用来约束鉴别器的能力,通过公式(5)约束kt能够使得网络更容易训练。正如许多论文中提到的那样,GAN不稳定且难以训练,因为鉴别器过快收敛,导致网络很容易崩溃,训练出无效的生成器。鉴别器过快收敛,从loss来分析就是:Lfake小于Lreal,生成器生成的热图足够真实以欺骗鉴别器。 此时,kt将增加,以使Lfake更具优势,从而使得鉴别器进行更多的训练才能识别生成的热图。它在Lfake上加速训练的比例取决于鉴别器落在与生成器的差距。当Lfake大于Lreal时原理类似。
对公式4进行解读:
公式4.1 输入为原始RGB图像X,GT热度图C。计算的Lreal表示鉴别器产生的结果和GT热度图之间的差别。
公式4.2输入为原始RGB图像X,生成器产生的热度图C^。计算Lfake表示鉴别器产生的结果和生成器产生热度图之间的差别。
公式4.3表示最终整个公式4,也就是鉴别器的loss的目的是最小化Lreal和Lfake,即整个优化过程要求Lreal小且Lfake大,直白的来说就是要求当输入为GT热度图时,鉴别器产生尽可能和GT相同的结果。当输入为生成器产生的热度图时,鉴别器产生尽可能和生成器不同的结果。如,如果右膝盖的信心在左膝盖附近很高,则训练有素的鉴别器将产生右膝盖的热图,该热图在左膝盖的位置具有较大的误差。由于鉴别器就像评论家一样, 它在输入热图上提供了详细的“注释”,并建议热图中的哪些部分未产生真实姿势。最终整个误差会在公式2中体现出来。而公式二会指导生成器进行进化,使得最终的生成器更好,降低整个误差。
5.算法整体流程
整体算法每一个迭代过程如下:
1.将GT热度图C,原始图像X输入到鉴别器,计算鉴别器的前向结果。为D(X,C)。同时计算鉴别器的loss,公式4.1,Lreal。
2.将原始图像X输入到生成器,计算生成器的前向结果C^。同时计算生成器loss,公式1,Lmse。
3.将虚假热度图C,原始图像X输入到鉴别器,计算鉴别器的前向结果。为D(X,C)。同时计算鉴别器的loss,公式4.2,Lfake。(累计Lreal和Lfake梯度值,并更新鉴别器参数,公式4.3)。
4.有了虚假热度图C和D(X,C),利用公式2计算对抗loss,Ladv,并更新生成器。
6.结果展示
作者在LSP和MPII两个人体关键点数据集上对上述自对抗网络进行了结果分析,从下表可以看出,利用对抗生成的方式能够有效提升模型效果,且不会增加推理时间。
网友评论