Fastai Focal Loss

Author: 深度学习模型优化 | Published 2019-06-30 13:14

    This snippet prevents results from varying wildly from run to run, assuming you are using PyTorch to implement your ideas.

    import os
    import random

    import numpy as np
    import torch

    def seed_everything(seed=42):
        """Seed every RNG that PyTorch training typically touches."""
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        # Make cuDNN choose deterministic algorithms (at some speed cost).
        torch.backends.cudnn.deterministic = True
    seed_everything()
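A quick way to convince yourself the seeding works (my own sketch, not from the original post): reset the seed and check that the same random tensors come back.

```python
import os
import random

import numpy as np
import torch

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

# Re-seeding replays the same random stream, so tensors come out identical.
seed_everything(42)
a = torch.randn(3)
seed_everything(42)
b = torch.randn(3)
```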
    

    With that out of the way, let's take a look at Focal Loss.

    # Source: https://www.kaggle.com/c/human-protein-atlas-image-classification/discussion/78109
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FocalLoss(nn.Module):
        """Focal loss on raw logits, for binary / multi-label targets."""
        def __init__(self, gamma=2):
            super().__init__()
            self.gamma = gamma

        def forward(self, logit, target):
            target = target.float()
            # Numerically stable binary cross-entropy with logits.
            max_val = (-logit).clamp(min=0)
            loss = logit - logit * target + max_val + \
                   ((-max_val).exp() + (-logit - max_val).exp()).log()

            # log(1 - p_t): log-probability assigned to the *wrong* class.
            invprobs = F.logsigmoid(-logit * (target * 2.0 - 1.0))
            # Modulating factor (1 - p_t)^gamma applied to the BCE term.
            loss = (invprobs * self.gamma).exp() * loss
            if loss.dim() == 2:
                loss = loss.sum(dim=1)
            return loss.mean()
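As a sanity check (my own sketch, not from the original post): with gamma=0 the modulating factor vanishes, so this loss should reduce exactly to plain BCE-with-logits, and any gamma > 0 can only shrink the loss by down-weighting easy examples.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, logit, target):
        target = target.float()
        max_val = (-logit).clamp(min=0)
        loss = logit - logit * target + max_val + \
               ((-max_val).exp() + (-logit - max_val).exp()).log()
        invprobs = F.logsigmoid(-logit * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        if loss.dim() == 2:
            loss = loss.sum(dim=1)
        return loss.mean()

torch.manual_seed(0)
logit = torch.randn(4, 5)
target = (torch.rand(4, 5) > 0.5).float()

# gamma=0: focal loss degenerates to BCE-with-logits (summed over classes).
focal0 = FocalLoss(gamma=0)(logit, target)
bce = F.binary_cross_entropy_with_logits(
    logit, target, reduction='none').sum(dim=1).mean()

# gamma>0 multiplies each term by (1 - p_t)^gamma <= 1, so the loss shrinks.
focal2 = FocalLoss(gamma=2)(logit, target)
```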
    

    The version above was downloaded from Kaggle. Next, let's look at another version, which implements the multi-class form

    Loss(x, class) = - \alpha \, (1 - \mathrm{softmax}(x)[class])^{\gamma} \, \log(\mathrm{softmax}(x)[class])

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FocalLoss(nn.Module):
        r"""
            This criterion is an implementation of Focal Loss, which is proposed in
            "Focal Loss for Dense Object Detection".

                Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

            The losses are averaged across observations for each minibatch.

            Args:
                alpha(1D Tensor) : the scalar factor for this criterion
                gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
                                       putting more focus on hard, misclassified examples
                size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                    However, if the field size_average is set to False, the losses are
                                    instead summed for each minibatch.
        """
        def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
            super(FocalLoss, self).__init__()
            if alpha is None:
                # Default: weight every class equally.
                self.alpha = torch.ones(class_num, 1)
            else:
                self.alpha = torch.as_tensor(alpha)
            self.gamma = gamma
            self.class_num = class_num
            self.size_average = size_average

        def forward(self, inputs, targets):
            N = inputs.size(0)
            C = inputs.size(1)
            P = F.softmax(inputs, dim=1)

            # One-hot mask selecting the target class for each sample.
            class_mask = inputs.new_zeros(N, C)
            ids = targets.view(-1, 1)
            class_mask.scatter_(1, ids, 1.)

            if inputs.is_cuda and not self.alpha.is_cuda:
                self.alpha = self.alpha.cuda()
            alpha = self.alpha[ids.view(-1)]

            # p_t: predicted probability of the true class.
            probs = (P * class_mask).sum(1).view(-1, 1)
            log_p = probs.log()

            batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

            if self.size_average:
                loss = batch_loss.mean()
            else:
                loss = batch_loss.sum()
            return loss
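This multi-class version has the same sanity check as the binary one (my own sketch, not from the original post): with gamma=0 and uniform alpha it should coincide with standard cross-entropy. The compact class below re-implements the same logic so the demo is self-contained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # Compact copy of the multi-class focal loss above, for a runnable demo.
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super().__init__()
        self.alpha = torch.ones(class_num, 1) if alpha is None else torch.as_tensor(alpha)
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, inputs, targets):
        N, C = inputs.shape
        P = F.softmax(inputs, dim=1)
        class_mask = inputs.new_zeros(N, C)
        class_mask.scatter_(1, targets.view(-1, 1), 1.)
        alpha = self.alpha[targets.view(-1)]
        probs = (P * class_mask).sum(1).view(-1, 1)       # p_t per sample
        batch_loss = -alpha * (1 - probs) ** self.gamma * probs.log()
        return batch_loss.mean() if self.size_average else batch_loss.sum()

torch.manual_seed(0)
logits = torch.randn(8, 4)           # 8 samples, 4 classes
targets = torch.randint(0, 4, (8,))

# With gamma=0 and uniform alpha, focal loss is exactly cross-entropy.
fl0 = FocalLoss(class_num=4, gamma=0)(logits, targets)
ce = F.cross_entropy(logits, targets)
```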
    

    Compare with OHEM. (I don't understand OHEM yet, so I'll just leave a placeholder here and come back to it later.)
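For reference, a minimal sketch of the idea behind OHEM (my own, not from the original post): instead of re-weighting every example smoothly as focal loss does, OHEM computes per-example losses and backpropagates only through the top-k hardest ones. The function name and `keep_ratio` parameter below are illustrative, not from any library.

```python
import torch
import torch.nn.functional as F

def ohem_cross_entropy(logits, targets, keep_ratio=0.5):
    """Cross-entropy over only the hardest examples in the batch (OHEM-style)."""
    losses = F.cross_entropy(logits, targets, reduction='none')
    k = max(1, int(keep_ratio * losses.numel()))
    topk_losses, _ = losses.topk(k)   # the k largest (hardest) per-example losses
    return topk_losses.mean()

torch.manual_seed(0)
logits = torch.randn(10, 4)
targets = torch.randint(0, 4, (10,))

hard = ohem_cross_entropy(logits, targets, keep_ratio=0.5)
full = F.cross_entropy(logits, targets)
# Averaging only the hardest half can never be smaller than the full mean.
```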

    For the full details of Focal Loss, see the paper "Focal Loss for Dense Object Detection".

Original link: https://www.haomeiwen.com/subject/umuhcctx.html