ECCV 2018 | Integral Human Pose Regression
https://arxiv.org/abs/1711.08229v4
现有的人体关键点解决方案多是基于heatmap进行监督检测,很少有算法直接回归关键点位置的。基于热度图进行人体关键点检测的方案拥有更强的位置监督能力,使得关键点检测的效果相较于直接回归坐标要好的多,然而基于热度图的方案在将热度图转换为人体关键点坐标的过程中如下公式所示需要求一个最大值。该操作是不可微的,这直接导致基于热度图的方案无法进行端到端的训练。因此基于热度图的方式只能学习得到热度图,最终在热度图中获取关键点位置。而通常情况下,为了节省网络计算的参数量和计算量,现有的方案中热度图的分辨率通常是要通过降采样模块在输入分辨率的基础上进行缩小,而要从缩小后的热度图中恢复出原始输入图像中关键点位置存在着不可避免的量化误差,从而可能导致关键点预测不够精准。CVPR2020仅有的两篇2Dpose的文章就着重在解决这个问题。然而早在2018年的ECCV中,本文就提出了一种无参数的方案,积分回归。该方法结合了基于热度图和回归的优点,通过端到端学习的方式,有效避免了基于热图学习的弊端。且该方法可以与任何基于热度图的方法进行结合,嵌套到现有的人体关键点解决方案中。
如下公式所示,为传统的基于热度图计算关键点位置的公式。正如上面提到的,利用该公式进行计算会有诸多缺点。

基于此本文提出了一种积分回归的方式,如下公式1所示,该方式又被称为soft-argmax,积分回归的方式通过归一化二维热度图并利用连续积分的方式对热度空间进行求和,最终求得的值即为当前热度图中关键点的位置。该方法有效整合了基于热度图和回归的优点。该公式的离散表示方式如下图中的3所示。又由于归一化要求热度图中的值为非负,且和为1。因此使用下图中的公式2对热度图进行建模。

根据上述公式,其对应的代码如下所示:
import torch
import torch.nn as nn
from torch.nn import functional as F
class SoftArgmax2D(nn.Module):
"""
Creates a module that computes Soft-Argmax 2D of a given input heatmap.
Returns the index of the maximum 2d coordinates of the give map.
:param beta: The smoothing parameter.
:param return_xy: The output order is [x, y].
"""
def __init__(self, beta: int = 100, return_xy: bool = False):
if not 0.0 <= beta:
raise ValueError(f"Invalid beta: {beta}")
super().__init__()
self.beta = beta
self.return_xy = return_xy
def forward(self, heatmap: torch.Tensor) -> torch.Tensor:
"""
:param heatmap: The input heatmap is of size B x N x H x W.
:return: The index of the maximum 2d coordinates is of size B x N x 2.
"""
heatmap = heatmap.mul(self.beta)
batch_size, num_channel, height, width = heatmap.size()
device: str = heatmap.device
softmax: torch.Tensor = F.softmax(
heatmap.view(batch_size, num_channel, height * width), dim=2
).view(batch_size, num_channel, height, width)
xx, yy = torch.meshgrid(list(map(torch.arange, [height, width])))
approx_x = (
softmax.mul(xx.float().to(device))
.view(batch_size, num_channel, height * width)
.sum(2)
.unsqueeze(2)
)
approx_y = (
softmax.mul(yy.float().to(device))
.view(batch_size, num_channel, height * width)
.sum(2)
.unsqueeze(2)
)
output = [approx_x, approx_y] if self.return_xy else [approx_y, approx_x]
output = torch.cat(output, 2)
return output
下图展示了基于积分回归的方式和基于热度图检测的方式,在COCO和MPII两个数据集上的性能比对。通过指标比对也说明了积分回归的优势。

网友评论