binary_cross_entropy_with_logits
- 接受任意形状的输入,target要求与输入形状一致。切记:target的值必须在[0,N-1]之间,其中N为类别数,否则会出现莫名其妙的错误,比如loss为负数。
- 计算其实就是交叉熵,不过输入不要求在0,1之间,该函数会自动添加sigmoid运算
- 默认的reduction方式为mean
- 下面的实现代码中loss2是根据公式写的,loss3来源于网上的一段代码link
import torch
import torch.nn as nn
from torch.nn import functional as F
gt=torch.zeros(10)
# gt[3]=1
pred=torch.rand(10)
print(pred)
print(gt)
loss=F.binary_cross_entropy_with_logits(pred,gt)
loss2=torch.zeros(10)
for i in range(len(gt)):
loss2[i]=-gt[i]*torch.log(torch.sigmoid(pred[i]))-(1-gt[i])*torch.log(1-torch.sigmoid(pred[i]))
max_val = (-pred).clamp(min=0)
loss3 = pred - pred * gt + max_val + ((-max_val).exp() + (-pred - max_val).exp()).log()
print('bce loss',loss)
print('cal loss1',loss2.mean(),loss2.sum())
print('cal loss2',loss3.mean(),loss3.sum())
网友评论