美文网首页pytorch基础
binary_cross_entropy_with_logits

binary_cross_entropy_with_logits

作者: Cat丹 | 来源:发表于2020-03-30 17:16 被阅读0次

binary_cross_entropy_with_logits

  • 接受任意形状的输入,target要求与输入形状一致。切记:target的值必须在[0,N-1]之间,其中N为类别数,否则会出现莫名其妙的错误,比如loss为负数。
  • 计算其实就是交叉熵,不过输入不要求在0,1之间,该函数会自动添加sigmoid运算
  • 默认的reduction方式为mean
  • 下面的实现代码中loss2是根据公式写的,loss3来源于网上的一段代码link
import torch
import torch.nn as nn
from torch.nn import functional as F

gt=torch.zeros(10)
# gt[3]=1
pred=torch.rand(10)
print(pred)
print(gt)
loss=F.binary_cross_entropy_with_logits(pred,gt)
loss2=torch.zeros(10)
for i in range(len(gt)):
    loss2[i]=-gt[i]*torch.log(torch.sigmoid(pred[i]))-(1-gt[i])*torch.log(1-torch.sigmoid(pred[i]))
max_val = (-pred).clamp(min=0)
loss3 = pred - pred * gt + max_val + ((-max_val).exp() + (-pred - max_val).exp()).log()
print('bce loss',loss)
print('cal loss1',loss2.mean(),loss2.sum())
print('cal loss2',loss3.mean(),loss3.sum())

相关文章

网友评论

    本文标题:binary_cross_entropy_with_logits

    本文链接:https://www.haomeiwen.com/subject/vlgvuhtx.html