美文网首页
关于Focal loss 和 GHM

关于Focal loss 和 GHM

作者: _从前从前_ | 来源:发表于2019-12-30 17:15 被阅读0次

    这里直接贴一个知乎链接,可以说讲得很好了https://zhuanlan.zhihu.com/p/80594704

    在单阶段的目标检测任务中,由于前景和背景不平衡的问题导致效果不如两阶段的方法(这个观点是否正确还有待研究)。

    Focal loss

    通过两个参数,一个控制平衡前景和背景的Loss贡献,一个控制简单样本的loss贡献。在目标检测任务中,简单样本虽然产生的loss很小,但是大部分样本均是简单样本,所以梯度仍然是由简单样本主导的。Focal Loss引入了两个超参数尝试去解决了这个问题。

    GHM

    通过划分bins,将梯度标准化,去解决上述提到的问题。GHM的想法是,我们确实不应该过多关注易分样本,但是特别难分的样本(outliers,离群点)也不该关注啊!那怎么同时衰减易分样本和特别难分的样本呢?太简单了,谁的数量多衰减谁呗!那怎么衰减数量多的呢?简单啊,定义一个变量,让这个变量能衡量出一定梯度范围内的样本数量——这不就是物理上密度的概念吗?
    其中密度的估计用到了EMA,考虑了全局的样本分布,这也是focal loss不具备的。只是网络的前期是否能直接使用GHM?这也是否会导致前期训练的不稳定?
    在自己的数据集上。GHM的表现不如focal loss,还需要更多的探索。

    GHMC mmdetection代码解析

    # 注册loss函数
    @LOSSES.register_module
    class GHMC(nn.Module):
        """GHM Classification Loss.
    
        Details of the theorem can be viewed in the paper
        "Gradient Harmonized Single-stage Detector".
        https://arxiv.org/abs/1811.05181
    
        Args:
            bins (int): Number of the unit regions for distribution calculation.
            momentum (float): The parameter for moving average.
            use_sigmoid (bool): Can only be true for BCE based loss now.
            loss_weight (float): The weight of the total GHM-C loss.
        """
    
        def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0):
            super(GHMC, self).__init__()
            self.bins = bins
            self.momentum = momentum
            edges = torch.arange(bins + 1).float() / bins
            self.register_buffer('edges', edges)
            self.edges[-1] += 1e-6
            if momentum > 0:
                acc_sum = torch.zeros(bins)
                self.register_buffer('acc_sum', acc_sum)
            self.use_sigmoid = use_sigmoid
            if not self.use_sigmoid:
                raise NotImplementedError
            self.loss_weight = loss_weight
    
        def forward(self, pred, target, label_weight, *args, **kwargs):
            """Calculate the GHM-C loss.
    
            Args:
                pred (float tensor of size [batch_num, class_num]):
                    The direct prediction of classification fc layer.
                target (float tensor of size [batch_num, class_num]):
                    Binary class target for each sample.
                label_weight (float tensor of size [batch_num, class_num]):
                    the value is 1 if the sample is valid and 0 if ignored.
            Returns:
                The gradient harmonized loss.
            """
            # the target should be binary class label
            if pred.dim() != target.dim():
                target, label_weight = _expand_binary_labels(
                target, label_weight, pred.size(-1))
            target, label_weight = target.float(), label_weight.float()
            edges = self.edges
            mmt = self.momentum
            weights = torch.zeros_like(pred)
    
            # gradient length
            # sigmoid梯度计算
            g = torch.abs(pred.sigmoid().detach() - target)
            # 有效的label的位置
            valid = label_weight > 0
            # 有效的label的数量
            tot = max(valid.float().sum().item(), 1.0)
            n = 0  # n valid bins
            for i in range(self.bins):
                # 将对应的梯度值划分到对应的bin中, 0-1
                inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
                # 该bin中存在多少个样本
                num_in_bin = inds.sum().item()
                if num_in_bin > 0:
                    if mmt > 0:
                        # moment计算num bin
                        self.acc_sum[i] = mmt * self.acc_sum[i] \
                            + (1 - mmt) * num_in_bin
                        # 权重等于总数/num bin
                        weights[inds] = tot / self.acc_sum[i]
                    else:
                        weights[inds] = tot / num_in_bin
                    n += 1
            if n > 0:
                # scale系数
                weights = weights / n
    
            loss = F.binary_cross_entropy_with_logits(
                pred, target, weights, reduction='sum') / tot
            return loss * self.loss_weight
    

    相关文章

      网友评论

          本文标题:关于Focal loss 和 GHM

          本文链接:https://www.haomeiwen.com/subject/lapmoctx.html