美文网首页
2020-09-14 DiceLoss

2020-09-14 DiceLoss

作者: Joyner2018 | 来源:发表于2020-09-14 22:21 被阅读0次

    #!/usr/bin/env python

    # -*- coding: utf-8 -*-

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np

    def make_one_hot(input, num_classes):
        """Convert a class-index tensor to a one-hot encoded tensor.

        Args:
            input: Integer tensor of shape ``[N, 1, *]`` holding class indices,
                expected to lie in ``[0, num_classes)``. Any integer dtype is
                accepted; it is cast to int64 because ``scatter_`` only takes
                int64 indices.
            num_classes: Number of classes, i.e. the size of the output's dim 1.

        Returns:
            A tensor of shape ``[N, num_classes, *]`` (torch's default floating
            dtype) with a 1 at each position's class index along dim 1 and 0
            elsewhere. The result is always allocated on the CPU, regardless of
            the device ``input`` lives on.

        Raises:
            ``scatter_`` is expected to raise (RuntimeError) if an index falls
            outside ``[0, num_classes)``.
        """
        # Same shape as ``input`` but with the singleton channel dimension
        # widened to ``num_classes``.
        shape = list(input.shape)
        shape[1] = num_classes
        result = torch.zeros(shape)
        # scatter_ requires int64 indices; label maps are often uint8/int32.
        return result.scatter_(1, input.cpu().long(), 1)
    

    class BinaryDiceLoss(nn.Module):
        r"""Dice loss for a single (binary) channel.

        For each sample, after flattening every non-batch dimension, computes
        ``1 - (2 * sum(predict * target) + smooth) / (sum(predict^p + target^p) + smooth)``.
        With 0/1 values and ``predict == target`` the ratio is exactly 1, so the
        loss is 0.

        Args:
            smooth: Constant added to numerator and denominator; keeps the ratio
                defined when both inputs are all zero. Default: 1.
            p: Exponent used in the denominator, ``\sum{x^p} + \sum{y^p}``.
                Default: 2.
            reduction: ``'mean'`` returns the mean over the batch, ``'sum'`` the
                sum over the batch, ``'none'`` a tensor of shape ``[N]``.
                Default: ``'mean'``.

        Shape:
            predict: ``[N, *]``.
            target: expected to match ``predict`` element-for-element per sample
                (only the batch sizes are asserted).

        Returns:
            Loss tensor reduced according to ``reduction``.

        Raises:
            AssertionError: if the batch sizes of ``predict`` and ``target`` differ.
            ValueError: from ``forward`` if ``reduction`` is not recognised.
        """

        def __init__(self, smooth=1, p=2, reduction='mean'):
            super(BinaryDiceLoss, self).__init__()
            self.smooth = smooth
            self.p = p
            self.reduction = reduction

        def forward(self, predict, target):
            assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
            # Flatten everything but the batch dimension so the sums are per sample.
            predict = predict.contiguous().view(predict.shape[0], -1)
            target = target.contiguous().view(target.shape[0], -1)

            # The Dice coefficient has a factor 2 in the numerator; without it a
            # perfect prediction would still score a loss of ~0.5.
            num = 2 * torch.sum(torch.mul(predict, target), dim=1) + self.smooth
            den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

            loss = 1 - num / den

            if self.reduction == 'mean':
                return loss.mean()
            if self.reduction == 'sum':
                return loss.sum()
            if self.reduction == 'none':
                return loss
            # ValueError subclasses Exception, so existing ``except Exception``
            # handlers keep working.
            raise ValueError('Unexpected reduction {}'.format(self.reduction))
    

    class DiceLoss(nn.Module):
        """Multi-class Dice loss averaged over the channel dimension.

        ``forward`` applies softmax to ``predict`` along dim 1, so ``predict``
        should be unnormalised scores (logits). ``target`` must already be
        one-hot encoded with exactly the same shape as ``predict``.

        Args:
            weight: Optional per-class weights; any indexable sequence of length
                ``num_classes`` (tensor, numpy array, list, ...).
            ignore_index: Class index to skip, or ``None`` to include all classes.
            **kwargs: Forwarded to ``BinaryDiceLoss`` (``smooth``, ``p``,
                ``reduction``).

        Shape:
            predict: ``[N, C, *]``.
            target: same shape as ``predict``.

        Returns:
            The per-class ``BinaryDiceLoss`` values (multiplied by ``weight[i]``
            when weights are given) summed and divided by ``C``. The divisor is
            always ``C``, even when ``ignore_index`` skips a class.

        Raises:
            AssertionError: if the shapes differ, or if ``len(weight) != C``.
        """

        def __init__(self, weight=None, ignore_index=None, **kwargs):
            super(DiceLoss, self).__init__()
            self.kwargs = kwargs
            self.weight = weight
            self.ignore_index = ignore_index

        def forward(self, predict, target):
            assert predict.shape == target.shape, 'predict & target shape do not match'
            num_classes = target.shape[1]
            if self.weight is not None:
                # Validate once up front; len() works for tensors, arrays and lists.
                assert len(self.weight) == num_classes, \
                    'Expect weight shape [{}], get[{}]'.format(num_classes, len(self.weight))

            dice = BinaryDiceLoss(**self.kwargs)
            total_loss = 0
            predict = F.softmax(predict, dim=1)

            for i in range(num_classes):
                if i == self.ignore_index:
                    continue
                dice_loss = dice(predict[:, i], target[:, i])
                if self.weight is not None:
                    # Out-of-place multiply; the original read a non-existent
                    # ``self.weights`` attribute here.
                    dice_loss = dice_loss * self.weight[i]
                total_loss += dice_loss

            return total_loss / num_classes

    相关文章

      网友评论

          本文标题:2020-09-14 DiceLoss

          本文链接:https://www.haomeiwen.com/subject/pibtyktx.html