CVPR 2018 | Cascaded Pyramid Network for Multi-Person Pose Estimation
https://github.com/chenyilun95/tf-cpn
1.文章概述
本文提出了一种级联金字塔网络CPN,该网络由全局金字塔网络(GlobalNet)和利用在线难例挖掘机制的精馏网络(RefineNet)组成。GlobalNet是一个特征金字塔网络,可以成功地定位“简单”的关键点(如眼睛和手),但可能无法准确识别被遮挡或看不见的关键点。RefineNet尝试通过整合来自GlobalNet的所有尺度的特征,以及在线难例关键点挖掘损失来处理“复杂”关键点的精确定位。
如下图所示,Cascaded Pyramid Network主要由两部分组成:GlobalNet和RefineNet。
2.GlobalNet
如下图所示,GlobalNet以ResNet为基础框架,使用与FPN相似的特征金字塔结构来估计关键点。每一个特征尺度多会输出对应的关键点信息。作者称这种结构为GlobalNet。
基于ResNet主干网的GlobalNet可以有效地定位眼睛等关键点,但可能无法准确定位髋部位置。像髋部这样的关键点定位通常需要更多的上下文信息和处理,而不是附近的外观特征。在许多情况下,单凭一个Global网络很难直接识别这些关键点。基于此作者在此后接了一个RefineNet。
3.RefineNet
如下图所示,在GlobalNet生成的特征金字塔表示的基础上,作者附加了一个细化网络来处理难例关键点。为了提高信息传输的效率和保持信息传输的完整性,RefineNet将不同的层次的特征进行上采样后concat。与堆叠沙漏的细分策略不同,RefineNet将所有金字塔特性串联起来,而不是简单地使用沙漏模块末尾的上采样特性。
随着网络训练的不断深入,网络对大多数简单关键点的关注越来越多,而对被遮挡和硬关键点的关注越来越少。我们应该确保这两类关键点之间的回归平衡。因此,在RefineNet训练中,根据训练损失来明确地在线选择难例关键点,并仅从所选关键点反向传播梯度,该方法被称为OHKM。如下代码所示为OHKM损失函数,从中可以看出该函数就是对MSE输出的结果进行了排序,并筛选其中难例部分进行重点回归。
class JointsOHKMMSELoss(nn.Module):
def __init__(self, use_target_weight, topk=8):
super(JointsOHKMMSELoss, self).__init__()
self.criterion = nn.MSELoss(reduction='none')
self.use_target_weight = use_target_weight
self.topk = topk
def ohkm(self, loss):
ohkm_loss = 0.
for i in range(loss.size()[0]):
sub_loss = loss[i]
topk_val, topk_idx = torch.topk(
sub_loss, k=self.topk, dim=0, sorted=False
)
tmp_loss = torch.gather(sub_loss, 0, topk_idx)
ohkm_loss += torch.sum(tmp_loss) / self.topk
ohkm_loss /= loss.size()[0]
return ohkm_loss
def forward(self, output, target, target_weight):
batch_size = output.size(0)
num_joints = output.size(1)
heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
loss = []
for idx in range(num_joints):
heatmap_pred = heatmaps_pred[idx].squeeze()
heatmap_gt = heatmaps_gt[idx].squeeze()
if self.use_target_weight:
loss.append(0.5 * self.criterion(
heatmap_pred.mul(target_weight[:, idx]),
heatmap_gt.mul(target_weight[:, idx])
))
else:
loss.append(
0.5 * self.criterion(heatmap_pred, heatmap_gt)
)
loss = [l.mean(dim=1).unsqueeze(dim=1) for l in loss]
loss = torch.cat(loss, dim=1)
return self.ohkm(loss)
4.结果展示
下图展示了不同阈值的NMS策略的性能,结果显示Soft-NMS表现出了最优性能。
下图结果显示了OHKM,在线难例挖掘的有效性。
最终的结果也显示了本文提出的策略的有效性,但总的来说本文提出的OHKM反而被其他SOTA算法广泛使用。
网友评论