美文网首页
nll_loss是什么?原来是log_softmax的损失函数!

nll_loss是什么?原来是log_softmax的损失函数!

作者: Co酱_秋乏术 | 来源:发表于2023-01-16 10:56 被阅读0次

    a = torch.Tensor([2, 3, 4])

    # 使用 pytorch 内置的 softmax 函数计算

    softmax = a.softmax(dim=0)

    print( softmax)

    # tensor([0.0900, 0.2447, 0.6652])

    log_softmax()

    使用 log_softmax 计算 softmax 的对数值,log_softmax(x) 等价于 log(softmax(x))

    最终分类如果使用 log_softmax,损失函数需要使用 nll_loss(Negative Log Likelihood )。

    nll_loss 损失函数的表达式为

    相当于只取 log(softmax(x)) 预测结果向量 label 对应索引值的相反数。

    log_softmax

    log_softmax = a.log_softmax(dim=0)

    print('-- log_softmax : ', log_softmax)

    # tensor([-2.4076, -1.4076, -0.4076])

    nll_loss

    print(manual_log_softmax.unsqueeze(1))

    # tensor([[-2.4076], [-1.4076], [-0.4076]])

    nll_loss = F.nll_loss(manual_log_softmax.unsqueeze(1), torch.LongTensor([0, 0, 0]), reduction='none')

    print('-- nll_loss : ', nll_loss)

    # tensor([2.4076, 1.4076, 0.4076])



    cross_entropy_loss

    log_softmax 和 nll_loss 经常组合使用,常见于各种分类算法

    cross_entropy_loss 底层实际也是调用的 log_softmax 和 nll_loss。该算子一步到位的将 softmax、log 及 nll 操作全部完成。

    b = torch.Tensor([ [1, 2], [0.5, 3], [0.9, 4] ])

    loss = nn.CrossEntropyLoss()

    l1 = loss(b, torch.LongTensor([0, 1, 0]))

    print('-- loss : ', l1)

    # tensor(1.5121)

    相关文章

      网友评论

          本文标题:nll_loss是什么?原来是log_softmax的损失函数!

          本文链接:https://www.haomeiwen.com/subject/qfrocdtx.html