美文网首页pytorch基础
binary_cross_entropy_with_logits

binary_cross_entropy_with_logits

作者: Cat丹 | 来源:发表于2020-03-30 17:16 被阅读0次

    binary_cross_entropy_with_logits

    • 接受任意形状的输入,target要求与输入形状一致。切记:target的值必须在[0,N-1]之间,其中N为类别数,否则会出现莫名其妙的错误,比如loss为负数。
    • 计算其实就是交叉熵,不过输入不要求在0,1之间,该函数会自动添加sigmoid运算
    • 默认的reduction方式为mean
    • 下面的实现代码中loss2是根据公式写的,loss3来源于网上的一段代码link
    import torch
    import torch.nn as nn
    from torch.nn import functional as F
    
    gt=torch.zeros(10)
    # gt[3]=1
    pred=torch.rand(10)
    print(pred)
    print(gt)
    loss=F.binary_cross_entropy_with_logits(pred,gt)
    loss2=torch.zeros(10)
    for i in range(len(gt)):
        loss2[i]=-gt[i]*torch.log(torch.sigmoid(pred[i]))-(1-gt[i])*torch.log(1-torch.sigmoid(pred[i]))
    max_val = (-pred).clamp(min=0)
    loss3 = pred - pred * gt + max_val + ((-max_val).exp() + (-pred - max_val).exp()).log()
    print('bce loss',loss)
    print('cal loss1',loss2.mean(),loss2.sum())
    print('cal loss2',loss3.mean(),loss3.sum())
    

    相关文章

      网友评论

        本文标题:binary_cross_entropy_with_logits

        本文链接:https://www.haomeiwen.com/subject/vlgvuhtx.html