美文网首页
Focal Loss理论及PyTorch实现

Focal Loss理论及PyTorch实现

作者: 几时见得清梦 | 来源:发表于2019-10-24 10:33 被阅读0次

一、基本理论

  1. 采用soft - gamma: 在训练的过程中阶段性的增大gamma 可能会有更好的性能提升。
  2. alpha 与每个类别在训练数据中的频率有关。
  3. F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同。
    F.nll_loss中实现了对于target的one-hot encoding,将其编码成与input shape相同的tensor,然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production。
基于alpha=1采用不同的gamma值进行实验的结果
  1. focal loss解决了什么问题?
    (1)不同类别不均衡
    (2)难易样本不均衡

  2. 在retinanet中,除了使用呢focal loss外,还对初始化做了特殊处理,具体是怎么做的?

在retinanet中,对 classification subnet 的最后一层conv设置它的偏置b为:

b=−log((1−π)/π)

π代表先验概率,就是类别不平衡中个数少的那个类别占总数的百分比,在检测中就是代表object的anchor占所有anchor的比重,论文中设置的为0.01。

二、公式

标准的Cross Entropy 为:[图片上传失败...(image-286df1-1571884440851)]

Focal Loss 为:[图片上传失败...(image-460db1-1571884440851)]

其中,[图片上传失败...(image-d6c655-1571884440851)]

关于Focal Loss的前向与后向推导见:知乎:Focal Loss 的前向与后向公式推导

三、代码实现

一、来自Kaggle的实现(基于二分类交叉熵实现)

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

二、来自知乎大佬的实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)


        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        probs = (P*class_mask).sum(1).view(-1,1)

        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)


        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

参考

  1. 知乎:Focal Loss 的Pytorch 实现以及实验
  2. Kaggle:A Pytorch implementation of Focal Loss
  3. GitHub:CoinCheung/pytorch-loss
  4. GitHub:Hsuxu/Loss_ToolBox-PyTorch
  5. 知乎:focal loss理解与初始化偏置b设置解释
  6. 个人博客:对focal loss中bias init的解释很好

相关文章

  • Focal Loss理论及PyTorch实现

    一、基本理论 采用soft - gamma: 在训练的过程中阶段性的增大gamma 可能会有更好的性能提升。 al...

  • Focal loss 以及 pytorch实现

    Focal loss 是在 Focal Loss for Dense Object Detection[https...

  • Focal loss特点和不均衡问题

    focal loss来源于论文《Focal loss for dense object detection》, ...

  • 2019-03-27

    Focal Loss

  • ICCV2017:Focal Loss for Dense Ob

    这篇有关Focal Loss的paper是何恺明大神提出的又一经典paper,除了提出Focal Loss还提出了...

  • Focal loss函数及代码

    一、Focal loss损失函数 Focal Loss的引入主要是为了解决**难易样本数量不平衡****(注意,有...

  • Focal Loss 原理及实践

    1 关于Focal Loss Focal Loss 是一个在交叉熵(CE)基础上改进的损失函数,来自ICCV201...

  • Focal Loss

    一、Focal Loss定义   关于指标αt的选择问题,论文中描述如下:  In practice,α may ...

  • focal loss

    https://blog.csdn.net/u014380165/article/details/77019084

  • Focal loss

    Focal loss主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单...

网友评论

      本文标题:Focal Loss理论及PyTorch实现

      本文链接:https://www.haomeiwen.com/subject/sskoectx.html