美文网首页
Focal Loss损失函数(pytorch实现)

Focal Loss损失函数(pytorch实现)

作者: 小黄不头秃 | 来源:发表于2023-06-29 15:59 被阅读0次

Focal Loss 是何凯明大神提出的一个新的损失函数,其基于交叉熵损失函数做了一些修改。Focal Loss源自ICCV2017的一篇论文:Best student paper——Focal Loss for Dense Object Detection。

论文下载链接为:Lin_Focal_Loss_for_ICCV_2017_paper.pdf

Focal Loss的提出主要是解决机器视觉领域中的样本数量不均衡的问题,它还强调了样本的难易性。当数据集中的某一类的样本过少,其训练难度也相对较大,那么Focal Loss就是为了解决这个问题。

一、Focal Loss 损失函数

首先我们看一下,交叉熵损失函数的公式:

这里y为真实样本的概率分布,p为预测的概率分布。这里为了简化推导我们可以重新定义pt:

所以,上述的交叉熵损失函数就变成了,如下形式:

并且有人提出带权重的交叉熵损失函数,其公式如下:

这个由人为设定的\alpha虽然能够解决一定的正负样本不均衡问题,但是其还是没有办法让神经网络去区分样本的难易程度。Focal Loss认为,数据的难易程度其实是由模型来进行判断的,也就是说我们可以将模型的输出作为数据的难易程度判断的标准。于是大佬们设计出了如下的Focal Loss:

使用(1-p_t)^γ来作为难易程度的代表,并且我们可以发现当γ=0时,Focal Loss就等于原来的交叉熵。

二、pytorch代码实现

"""
以二分类任务为例
"""
from torch import nn 
import torch 

class FocalLoss(nn.Module):
    def __init__(self, gama=1.5, alpha=0.25, weight=None, reduction="mean") -> None:
        super().__init__() 
        self.loss_fcn = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)
        self.gama = gama 
        self.alpha = alpha 

    def forward(self, pre, target):
        logp = self.loss_fcn(pre, target)
        p = torch.exp(-logp) 
        loss = (1-p)**self.gama * self.alpha * logp
        return loss.mean()

相关文章

网友评论

      本文标题:Focal Loss损失函数(pytorch实现)

      本文链接:https://www.haomeiwen.com/subject/wqmiydtx.html