
focal loss

Author: 三方斜阳 | Published 2021-10-07 20:39

    Prerequisite: cross-entropy - 简书 (jianshu.com)

    Understanding:

    To address class imbalance, the authors propose a new loss function, focal loss, obtained by modifying the standard cross-entropy loss. By down-weighting easy examples, it makes the model concentrate on hard examples during training. Focal loss was designed mainly to fix the severe positive/negative imbalance in one-stage object detection: it reduces the weight that the huge number of easy negatives carries in training, and can also be read as a form of hard example mining.

    • Focal loss is a modification of the cross-entropy loss, so first recall the binary cross-entropy (written out below).
    • Here y' is the output after the activation function, so it lies between 0 and 1. With plain cross-entropy, a positive sample's loss shrinks as its predicted probability grows, and a negative sample's loss shrinks as its predicted probability falls. Iterating over a large number of easy samples with this loss is slow and may fail to reach a good optimum. How does focal loss improve on this?
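    For reference, the binary cross-entropy discussed here (y is the label, y' the predicted probability) is:

    $$
    CE(y, y') =
    \begin{cases}
    -\log(y') & y = 1 \\
    -\log(1 - y') & y = 0
    \end{cases}
    $$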


    • Focal loss first multiplies the cross-entropy by a modulating factor with exponent gamma > 0 (see the formula below), which shrinks the loss of easy, well-classified samples so that training focuses on hard, misclassified ones.
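    Concretely, adding the modulating factor gives the focal loss from the paper:

    $$
    FL(y, y') =
    \begin{cases}
    -(1 - y')^{\gamma}\,\log(y') & y = 1 \\
    -(y')^{\gamma}\,\log(1 - y') & y = 0
    \end{cases}
    $$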

    For example, with gamma = 2: a positive sample predicted at 0.95 is clearly easy, so (1 - 0.95)^2 is tiny and its loss becomes much smaller, while a sample predicted at 0.3 keeps a relatively large loss. The same holds for negatives: a negative predicted at 0.1 should contribute far less loss than one predicted at 0.7. At a prediction of 0.5 the loss is only scaled by 0.25, i.e. cut to a quarter, so these hard-to-separate samples keep comparatively more weight. This suppresses the influence of easy samples; only a very large number of such low-loss samples stacked together would still have a noticeable effect.
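    A quick numeric check of the modulating factor, as a minimal sketch in plain Python:

    # (1 - p_t) ** gamma for gamma = 2, where p_t is the predicted probability of the
    # true class: the easy sample (0.95) is scaled down 100x more than the borderline one (0.5).
    gamma = 2
    for p_t in (0.95, 0.7, 0.5, 0.3, 0.1):
        print(p_t, (1 - p_t) ** gamma)
    # 0.95 -> 0.0025, 0.5 -> 0.25, 0.1 -> 0.81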

    • In addition, a balancing factor alpha is introduced to even out the intrinsic imbalance between positive and negative samples (see the α-balanced form below). The paper uses alpha = 0.25, i.e. positives are weighted less than negatives, because negatives are easy to classify.

      Adding alpha alone can balance the importance of positives versus negatives, but it does not address the easy-versus-hard sample problem.
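    In the paper's p_t notation (p_t is the predicted probability of the ground-truth class, and alpha_t is alpha for positives and 1 - alpha for negatives), the full α-balanced focal loss is:

    $$
    FL(p_t) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t)
    $$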

    • gamma controls how quickly easy samples are down-weighted. With gamma = 0 the loss reduces to plain cross-entropy; as gamma increases, the effect of the modulating factor grows. The paper's experiments find gamma = 2 to work best.

    focal loss implementation

    • Simple binary (two-class) case:
    import torch
    import torch.nn.functional as F


    class FocalLoss:
        def __init__(self, alpha_t=None, gamma=0):
            """
            :param alpha_t: a list of per-class weights (the alpha balancing factor)
            :param gamma: focusing parameter; gamma=0 recovers plain cross-entropy
            """
            self.alpha_t = torch.tensor(alpha_t) if alpha_t else None
            self.gamma = gamma

        def __call__(self, outputs, targets):
            if self.alpha_t is None and self.gamma == 0:
                # no alpha, no gamma: plain cross-entropy
                focal_loss = torch.nn.functional.cross_entropy(outputs, targets)

            elif self.alpha_t is not None and self.gamma == 0:
                # alpha only: class-weighted cross-entropy
                if self.alpha_t.device != outputs.device:
                    self.alpha_t = self.alpha_t.to(outputs)
                focal_loss = torch.nn.functional.cross_entropy(outputs, targets, weight=self.alpha_t)

            elif self.alpha_t is None and self.gamma != 0:
                # gamma only: modulate each sample's CE by (1 - p_t) ** gamma
                ce_loss = torch.nn.functional.cross_entropy(outputs, targets, reduction='none')
                p_t = torch.exp(-ce_loss)  # probability of the true class
                focal_loss = ((1 - p_t) ** self.gamma * ce_loss).mean()

            else:  # both alpha and gamma
                if self.alpha_t.device != outputs.device:
                    self.alpha_t = self.alpha_t.to(outputs)
                ce_loss = torch.nn.functional.cross_entropy(outputs, targets, reduction='none')
                p_t = torch.exp(-ce_loss)  # p_t taken from the unweighted CE
                ce_loss = torch.nn.functional.cross_entropy(outputs, targets, weight=self.alpha_t, reduction='none')
                focal_loss = ((1 - p_t) ** self.gamma * ce_loss).mean()  # mean over the batch

            return focal_loss


    if __name__ == '__main__':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        outputs = torch.tensor([[2, 1.],
                                [2.5, 1]], device=device)
        targets = torch.tensor([0, 1], device=device)
        print(torch.nn.functional.softmax(outputs, dim=1))

        fl = FocalLoss([0.5, 0.5], 2)
        loss = F.cross_entropy(outputs, targets)
        print(loss)
        print(fl(outputs, targets))
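    A note on the trick used above: F.cross_entropy with reduction='none' returns -log(p_t) per sample, so p_t is recovered as torch.exp(-ce_loss); the alpha weights are then applied through the weight argument of a second cross_entropy call before the (1 - p_t) ** gamma modulation and the final mean over the batch.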
    
    • The multi-class case is similar:
    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    
    
    from typing import List
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class FocalLoss(nn.Module):
        """
        Focal loss(https://arxiv.org/pdf/1708.02002.pdf)
        Shape:
            - input: (N, C)
            - target: (N)
            - Output: Scalar loss
        Examples:
            >>> loss = FocalLoss(gamma=2, alpha=[1.0]*7, reduction="mean")
            >>> input = torch.randn(3, 7, requires_grad=True)
            >>> target = torch.empty(3, dtype=torch.long).random_(7)
            >>> output = loss(input, target)
            >>> output.backward()
        """
        def __init__(self, gamma=0, alpha: List[float] = None, reduction="none"):
            super(FocalLoss, self).__init__()
            self.gamma = gamma
            self.alpha = alpha
            if alpha is not None:
                self.alpha = torch.FloatTensor(alpha)
            self.reduction = reduction
    
        def forward(self, input, target):
            # target: [N] -> [N, 1] so it can serve as an index for gather
            target = target.unsqueeze(-1)
            # [N, C] class probabilities and log-probabilities
            pt = F.softmax(input, dim=-1)
            logpt = F.log_softmax(input, dim=-1)
            # [N] probability / log-probability of the ground-truth class
            pt = pt.gather(1, target).squeeze(-1)
            logpt = logpt.gather(1, target).squeeze(-1)

            if self.alpha is not None:
                # [N] at[i] = alpha[target[i]]
                at = self.alpha.gather(0, target.squeeze(-1))
                logpt = logpt * at
    
            loss = -1 * (1 - pt) ** self.gamma * logpt
            if self.reduction == "none":
                return loss
            if self.reduction == "mean":
                return loss.mean()
            return loss.sum()
    
        @staticmethod
        def convert_binary_pred_to_two_dimension(x, is_logits=True):
            """
            Args:
                x: (*): (log) prob of some instance has label 1
                is_logits: if True, x represents log prob; otherwhise presents prob
            Returns:
                y: (*, 2), where y[*, 1] == log prob of some instance has label 0,
                                 y[*, 0] = log prob of some instance has label 1
            """
            probs = torch.sigmoid(x) if is_logits else x
            probs = probs.unsqueeze(-1)
            probs = torch.cat([1-probs, probs], dim=-1)
            logprob = torch.log(probs+1e-4)  # 1e-4 to prevent being rounded to 0 in fp16
            return logprob
    
        def __str__(self):
            return f"Focal Loss gamma:{self.gamma}"
    
        def __repr__(self):
            return str(self)
    
    loss = FocalLoss(gamma=2, alpha=[1.0]*7)
    input = torch.randn(3, 7, requires_grad=True)
    target = torch.empty(3, dtype=torch.long).random_(7)
    print(input)
    print(target)
    output = loss(input, target)  # one loss per example, since reduction="none" by default
    print(output)
    # output.sum().backward()  # reduce to a scalar before calling backward()
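    Because reduction defaults to "none", the FocalLoss module above returns one loss value per example; for training you would normally construct it with reduction="mean" (or sum/average the vector yourself) so that a scalar is available for backward().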
    

    Paper:
    https://arxiv.org/pdf/1708.02002.pdf
    Related blog posts:
    Focal loss论文详解 - 知乎 (zhihu.com)
    Focal Loss理解 - 三年一梦 - 博客园 (cnblogs.com)
