美文网首页
BCEWithLogitsLoss参数weight

BCEWithLogitsLoss参数weight

作者: 三方斜阳 | 来源:发表于2021-10-07 09:00 被阅读0次
1. weight:
  • a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.

    就是给出weight参数后,会将其shape和input的shape相匹配。回忆公式:
  • 默认情况,也就是weight=None时,上述公式中的Wn=1;当weight!=None时,也就意味着我们需要为每一个样本赋予权重Wi,这样weight的shape和input一致就很好理解了。
    首先看pytorch中weight参数作用后的结果,weight就是为每一个样本加权:
import torch
import torch.nn as nn
input = torch.tensor([[-0.4089,-1.2471,0.5907],
                      [-0.4897,-0.8267,-0.7349],
                      [0.5241,-0.1246,-0.4751]])
m=nn.Sigmoid()
S_input=m(input)

target=torch.FloatTensor([[0,1,1],[0,0,1],[1,0,1]])

w = [0.1, 0.9] # 标签0和标签1的权重
weight = torch.zeros(target.shape)  # 权重矩阵
for i in range(target.shape[0]):
    for j in range(target.shape[1]):
        weight[i][j] = w[int(target[i][j])]
print(weight)

BCEWithLogitsLoss=nn.BCEWithLogitsLoss(weight=weight)
loss = BCEWithLogitsLoss(input,target)
print(loss)
loss = 0.0
for i in range(S_input.shape[0]):
    for j in range(S_input.shape[1]):
        loss += -weight[i][j] * (target[i][j] * torch.log(S_input[i][j]) + (1 - target[i][j]) * torch.log(1 - S_input[i][j]))
print(loss/(S_input.shape[0]*S_input.shape[1])) # 默认取均值

tensor([[0.1000, 0.9000, 0.9000],
        [0.1000, 0.1000, 0.9000],
        [0.9000, 0.1000, 0.9000]])
tensor(0.4711)
tensor(0.4711)
  • pytorch官方的代码和自己实现的计算出的损失一致,再次说明BCEWithLogitsLoss的weight权重会分别对应的作用在每一个样本上。

相关文章

网友评论

      本文标题:BCEWithLogitsLoss参数weight

      本文链接:https://www.haomeiwen.com/subject/wpvonltx.html