Focal loss was introduced in Focal Loss for Dense Object Detection to address class imbalance in the training data.
It is a dynamically scaled cross-entropy loss, in which the scaling factor decays to zero as confidence in the correct class increases. This automatically down-weights the contribution of easy examples during training and quickly focuses the model on hard examples.
Focal loss is defined as:

$$\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$$

where $p_t$ is the model's estimated probability for the correct class of a sample.
When a sample is misclassified, $p_t$ is small, the modulating factor $(1 - p_t)^{\gamma}$ is close to 1, and the loss is essentially unaffected. As $p_t \to 1$, the factor approaches 0, so the loss of well-classified samples is down-weighted (see Figure 1 of the paper).
When $\gamma = 0$, focal loss reduces to the ordinary cross-entropy loss. In practice $\gamma = 2$ is the typical choice.
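To make the down-weighting concrete, here is a minimal standalone check (values chosen purely for illustration) comparing $\gamma = 0$ (plain cross entropy) with $\gamma = 2$:

import math

for p_t in (0.9, 0.1):
    ce = -math.log(p_t)       # gamma = 0: plain cross entropy
    fl = (1 - p_t) ** 2 * ce  # gamma = 2: focal loss
    print(f"p_t={p_t}: CE={ce:.4f}, FL={fl:.4f}, FL/CE={fl / ce:.2f}")

# p_t=0.9 (easy): CE=0.1054, FL=0.0011, FL/CE=0.01 -> loss cut by 100x
# p_t=0.1 (hard): CE=2.3026, FL=1.8651, FL/CE=0.81 -> nearly unchanged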
An $\alpha$-balanced variant of focal loss gives slightly better results in the paper's experiments:

$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$

where $\alpha_t$ is defined analogously to $p_t$: $\alpha$ for the positive class and $1 - \alpha$ for the negative class.
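As a quick sanity check on the formula (using the paper's defaults $\alpha = 0.25$, $\gamma = 2$): a well-classified positive sample with $p_t = 0.9$ contributes $-0.25 \cdot (1 - 0.9)^2 \cdot \log(0.9) \approx 2.6 \times 10^{-4}$, versus $\approx 0.105$ under plain cross entropy.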
PyTorch implementation:
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F


def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
            positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
            'none': No reduction will be applied to the output.
            'mean': The output will be averaged.
            'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    inputs = inputs.float()    # (B, C)
    targets = targets.float()  # (B, C)
    p = torch.sigmoid(inputs)  # (B, C)
    # Per-element binary cross entropy, computed from logits for numerical stability
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")  # (B, C)
    # p_t: predicted probability of the true class for each element
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor (1 - p_t)^gamma down-weights easy examples
    loss = ce_loss * ((1 - p_t) ** gamma)  # (B, C)
    if alpha >= 0:
        # alpha_t: alpha for positives, (1 - alpha) for negatives
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # (B, C)
        loss = alpha_t * loss  # (B, C)
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
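A minimal usage sketch with dummy inputs (the shapes are illustrative, not prescribed by the API), using the paper's recommended $\alpha = 0.25$, $\gamma = 2$:

logits = torch.randn(8, 4)                    # (B, C) raw scores from the model
labels = torch.randint(0, 2, (8, 4)).float()  # binary target per element
loss = sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2, reduction="mean")
print(loss)  # a scalar tensor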