
Focal Loss 是何凯明大神提出的一个新的损失函数,其基于交叉熵损失函数做了一些修改。Focal Loss源自ICCV2017的一篇论文:Best student paper——Focal Loss for Dense Object Detection。
论文下载链接为:Lin_Focal_Loss_for_ICCV_2017_paper.pdf。
Focal Loss的提出主要是解决机器视觉领域中的样本数量不均衡的问题,它还强调了样本的难易性。当数据集中的某一类的样本过少,其训练难度也相对较大,那么Focal Loss就是为了解决这个问题。
一、Focal Loss 损失函数
首先我们看一下,交叉熵损失函数的公式:

这里y为真实样本的概率分布,p为预测的概率分布。这里为了简化推导我们可以重新定义pt:

所以,上述的交叉熵损失函数就变成了,如下形式:

并且有人提出带权重的交叉熵损失函数,其公式如下:

这个由人为设定的虽然能够解决一定的正负样本不均衡问题,但是其还是没有办法让神经网络去区分样本的难易程度。Focal Loss认为,数据的难易程度其实是由模型来进行判断的,也就是说我们可以将模型的输出作为数据的难易程度判断的标准。于是大佬们设计出了如下的Focal Loss:

使用来作为难易程度的代表,并且我们可以发现当
时,Focal Loss就等于原来的交叉熵。

二、pytorch代码实现
"""
以二分类任务为例
"""
from torch import nn
import torch
class FocalLoss(nn.Module):
def __init__(self, gama=1.5, alpha=0.25, weight=None, reduction="mean") -> None:
super().__init__()
self.loss_fcn = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)
self.gama = gama
self.alpha = alpha
def forward(self, pre, target):
logp = self.loss_fcn(pre, target)
p = torch.exp(-logp)
loss = (1-p)**self.gama * self.alpha * logp
return loss.mean()
网友评论