Loss functions: triplet loss

Author: ltochange | Published: 2021-06-18 23:43

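    Unlike cross-entropy loss, which only considers the error between a sample and its class label, triplet loss looks at the distances between a sample and other samples. It comes from the paper Learning local feature descriptors with triplets and shallow convolutional neural networks.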

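    For a batch D(a, p, n) containing N samples (triplets), the loss for the i-th sample is:
    l_{i}=\max \left\{d\left(a_{i}, p_{i}\right)-d\left(a_{i}, n_{i}\right)+\operatorname{margin}, 0\right\}
    Here a, p, and n denote the anchor, the positive (same class as the anchor), and the negative (different class from the anchor). The distance function d measures how far the anchor is from the positive and from the negative. margin is a manually chosen constant. Minimizing the loss pulls the anchor closer to the positive and pushes it farther from the negative.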
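
    The formula above can be checked by hand on a tiny batch. The following is a minimal sketch, assuming Euclidean distance for d; the helper name triplet_loss_per_sample and the toy tensors are illustrative and not part of the original article. PyTorch's built-in function is called only as a cross-check (it adds a small eps inside the distance, so the match is approximate).

```python
import torch
import torch.nn.functional as F


def triplet_loss_per_sample(anchor, positive, negative, margin=1.0):
    """Hypothetical helper: l_i = max(d(a_i, p_i) - d(a_i, n_i) + margin, 0)."""
    # Euclidean distance between each anchor and its positive / negative
    d_ap = (anchor - positive).pow(2).sum(dim=1).sqrt()
    d_an = (anchor - negative).pow(2).sum(dim=1).sqrt()
    # Clamp at zero, exactly as the max{., 0} in the formula
    return torch.clamp(d_ap - d_an + margin, min=0.0)


anchor = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
positive = torch.tensor([[3.0, 4.0], [1.0, 0.0]])  # d(a, p) = 5 and 1
negative = torch.tensor([[0.0, 2.0], [0.0, 6.0]])  # d(a, n) = 2 and 6

losses = triplet_loss_per_sample(anchor, positive, negative, margin=1.0)
print(losses)  # tensor([4., 0.])

# Cross-check against PyTorch's functional API (approximate because of eps)
reference = F.triplet_margin_loss(anchor, positive, negative,
                                  margin=1.0, reduction="none")
print(torch.allclose(losses, reference, atol=1e-4))
```

    The first triplet violates the margin (5 - 2 + 1 = 4), while the second is already well separated (1 - 6 + 1 < 0), so its loss is clamped to 0.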

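    From the formula above:

    (1) When d\left(a_{i}, p_{i}\right)-d\left(a_{i}, n_{i}\right)+\operatorname{margin}<0, i.e. d\left(a_{i}, n_{i}\right) > d\left(a_{i}, p_{i}\right)+\operatorname{margin}, the loss for this sample is 0.

    In this case the anchor-negative distance exceeds the anchor-positive distance by more than margin. Such a triplet is treated as easy to classify: its error is simply ignored, so it contributes zero loss and zero gradient, and training concentrates on the harder samples.

    (2) When d\left(a_{i}, p_{i}\right)-d\left(a_{i}, n_{i}\right)+\operatorname{margin}>0, the loss for this sample is d\left(a_{i}, p_{i}\right)-d\left(a_{i}, n_{i}\right)+\operatorname{margin}, and there are two sub-cases:

    • d\left(a_{i}, p_{i}\right)+\operatorname{margin}>d\left(a_{i}, n_{i}\right)>d\left(a_{i}, p_{i}\right) , a hard sample: the negative is farther from the anchor than the positive, but by less than margin (often called a semi-hard triplet).

    • d\left(a_{i}, p_{i}\right)+\operatorname{margin}>d\left(a_{i}, p_{i}\right)>d\left(a_{i}, n_{i}\right) , a very hard sample that is likely to be misclassified, because the negative is closer to the anchor than the positive.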
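
    The three cases above can be sketched as a small classifier over precomputed distances. This is an illustrative sketch only: the function name categorize_triplets and the labels are assumptions made here, and the boundary cases (equality), which the case analysis does not spell out, are assigned by choice in the comments.

```python
import torch


def categorize_triplets(d_ap, d_an, margin=1.0):
    """Hypothetical helper: label each triplet by the case it falls into."""
    labels = []
    for ap, an in zip(d_ap.tolist(), d_an.tolist()):
        if an >= ap + margin:
            # Case (1): loss is 0 (equality also gives zero loss)
            labels.append("easy")
        elif an > ap:
            # Case (2), first bullet: d(a,p) < d(a,n) < d(a,p) + margin
            labels.append("hard")
        else:
            # Case (2), second bullet: d(a,n) <= d(a,p)
            labels.append("very hard")
    return labels


d_ap = torch.tensor([1.0, 1.0, 1.0])  # anchor-positive distances
d_an = torch.tensor([3.0, 1.5, 0.5])  # anchor-negative distances
labels = categorize_triplets(d_ap, d_an, margin=1.0)
print(labels)  # ['easy', 'hard', 'very hard']
```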

    TripletMarginLoss

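    import torch.nn.functional as F
    from torch.nn.modules.loss import _Loss


    # Class definition as it appears in the PyTorch source (details may vary by version)
    class TripletMarginLoss(_Loss):
        __constants__ = ['margin', 'p', 'eps', 'swap', 'reduction']

        def __init__(self, margin=1.0, p=2., eps=1e-6, swap=False, size_average=None,
                     reduce=None, reduction='mean'):
            super(TripletMarginLoss, self).__init__(size_average, reduce, reduction)
            self.margin = margin  # constant added to d(a, p) - d(a, n)
            self.p = p            # norm degree of the pairwise distance
            self.eps = eps        # small constant for numerical stability
            self.swap = swap

        def forward(self, anchor, positive, negative):
            # Delegate to the functional implementation
            return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p,
                                         eps=self.eps, swap=self.swap, reduction=self.reduction)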
    

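    PyTorch implements this loss in the torch.nn.TripletMarginLoss class; the function F.triplet_margin_loss can also be called directly. size_average and reduce are deprecated. reduction takes one of three values, mean, sum, or none, which determine the returned \ell(a, p, n). The default is mean, which matches how the overall loss is usually computed.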

    L=\left\{l_{1}, \ldots, l_{N}\right\}

    \ell(a, p, n)=\left\{\begin{array}{ll} L, & \text { if reduction }=\text { 'none' }\\ \frac{1}{N} \sum_{i=1}^{N} l_{i}, & \text { if reduction}=\text { 'mean' } \\ \sum_{i=1}^{N} l_{i}, & \text { if reduction }=\text { 'sum' }\end{array}\right.
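
    The relationship between the three reduction modes can be verified numerically. The sketch below uses random tensors with an arbitrary batch size and dimension chosen here for illustration; it only relies on F.triplet_margin_loss and its reduction argument.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
anchor = torch.randn(8, 16)
positive = torch.randn(8, 16)
negative = torch.randn(8, 16)

# 'none' returns the vector L = {l_1, ..., l_N}
per_sample = F.triplet_margin_loss(anchor, positive, negative, reduction="none")
# 'mean' and 'sum' reduce that vector to a scalar
mean_loss = F.triplet_margin_loss(anchor, positive, negative, reduction="mean")
sum_loss = F.triplet_margin_loss(anchor, positive, negative, reduction="sum")

print(per_sample.shape)  # torch.Size([8])
print(torch.allclose(mean_loss, per_sample.mean()))
print(torch.allclose(sum_loss, per_sample.sum()))
```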

    This class uses the following distance function. p defaults to 2, which gives the Euclidean distance.

    d\left(x_{i}, y_{i}\right)=\left\|\mathbf{x}_{i}-\mathbf{y}_{i}\right\|_{p}

    PyTorch also provides torch.nn.PairwiseDistance for computing this distance.
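
    As a quick illustration of the distance formula, the sketch below compares torch.nn.PairwiseDistance with a hand-written p-norm on toy vectors chosen here. Note that PairwiseDistance adds a small eps (1e-6 by default) to the difference for numerical stability, so the results agree only approximately.

```python
import torch
import torch.nn as nn

x = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
y = torch.tensor([[3.0, 4.0], [1.0, 3.0]])

# p=2: Euclidean distance per row, roughly 5 and 2 here
d_l2 = nn.PairwiseDistance(p=2)(x, y)
manual_l2 = (x - y).pow(2).sum(dim=1).sqrt()

# p=1: Manhattan distance per row, roughly 7 and 2 here
d_l1 = nn.PairwiseDistance(p=1)(x, y)
manual_l1 = (x - y).abs().sum(dim=1)

print(d_l2, d_l1)
print(torch.allclose(d_l2, manual_l2, atol=1e-4))
print(torch.allclose(d_l1, manual_l1, atol=1e-4))
```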

    Example:

    import torch
    import torch.nn as nn

    torch.manual_seed(20)  # fix the seed so the random inputs are reproducible
    # Default reduction="mean": returns a scalar averaged over the batch
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
    anchor = torch.randn(100, 128, requires_grad=True)
    positive = torch.randn(100, 128, requires_grad=True)
    negative = torch.randn(100, 128, requires_grad=True)
    output = triplet_loss(anchor, positive, negative)
    print(output.item())
    # reduction="none": returns one loss value per sample
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, reduction="none")
    output = triplet_loss(anchor, positive, negative)
    print(output.size())
    

    Output:

    1.1951137781143188
    torch.Size([100])
    

    TripletMarginWithDistanceLoss

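    This loss behaves essentially the same as TripletMarginLoss, except that a custom distance function can be passed in. When the distance function is torch.nn.PairwiseDistance (with default arguments on both sides), the two are equivalent.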

    Example:

    import torch
    import torch.nn as nn

    torch.manual_seed(20)  # same seed as the TripletMarginLoss example
    anchor = torch.randn(100, 128, requires_grad=True)
    positive = torch.randn(100, 128, requires_grad=True)
    negative = torch.randn(100, 128, requires_grad=True)
    # PairwiseDistance() defaults to p=2, the same distance TripletMarginLoss uses
    triplet_loss = nn.TripletMarginWithDistanceLoss(reduction="mean", distance_function=nn.PairwiseDistance())
    output = triplet_loss(anchor, positive, negative)
    print(output.item())

    # reduction="none": returns one loss value per sample
    triplet_loss = nn.TripletMarginWithDistanceLoss(reduction="none", distance_function=nn.PairwiseDistance())
    output = triplet_loss(anchor, positive, negative)
    print(output.size())
    

    The output matches that of TripletMarginLoss:

    1.1951137781143188
    torch.Size([100])
    

    Using a custom distance function:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(20)
    anchor = torch.randn(100, 128, requires_grad=True)
    positive = torch.randn(100, 128, requires_grad=True)
    negative = torch.randn(100, 128, requires_grad=True)

    # Custom distance function: L-infinity (largest absolute coordinate difference)
    def l_infinity(x1, x2):
        return torch.max(torch.abs(x1 - x2), dim=1).values

    triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
    output = triplet_loss(anchor, positive, negative)
    print(output.item())

    # Custom distance function (lambda): cosine distance = 1 - cosine similarity
    triplet_loss = nn.TripletMarginWithDistanceLoss(
        distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
    output = triplet_loss(anchor, positive, negative)
    print(output.item())
    

    Output:

    1.529929518699646
    1.0007251501083374
    
