pytorch: NLLLoss


Author: Andy512 | Published 2019-06-20 00:53

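In PyTorch's NLLLoss, a target equal to ignore_index contributes nothing to the loss. ignore_index defaults to -100, so a label of -100 is skipped, and in the example below, where it is the only label, the returned loss is zero. This applies only to the ignore_index value: other negative labels such as -1 are out of range and raise an error rather than giving zero.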

```python
import torch
import torch.nn as nn


class testCrossEntropy(nn.Module):
    def __init__(self):
        super().__init__()
        # CrossEntropyLoss = LogSoftmax + NLLLoss; ignore_index defaults to -100
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs, target):
        return self.criterion(inputs, target)


if __name__ == '__main__':
    # One sample, five classes; the probabilities sum to 1
    predict = torch.tensor([[0.0, 0.2, 0.7, 0.1, 0.0]])
    crit = testCrossEntropy()
    # Target -100 equals the default ignore_index, so the only sample is ignored
    # (reported as 0 in 2019; newer releases may give nan, since every target is ignored)
    v1 = crit(predict.log(), torch.tensor([-100]))
    # Target 1: loss = -log(0.2)
    v2 = crit(predict.log(), torch.tensor([1]))
    print('v1 =', v1.item())
    print('v2 =', v2.item())
```

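Output, as reported in the original run (2019):

```
v1 = 0.
v2 = 1.6094
```

Newer PyTorch releases may return nan for v1 instead: with the default reduction='mean', every target in this batch is ignored, so the average is taken over zero samples.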
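Why v2 comes out as 1.6094: the input is x = log p with p = (0, 0.2, 0.7, 0.1, 0). Entries with p_j = 0 become -inf and contribute e^{-inf} = 0 to the softmax, and because p sums to 1 the log-softmax inside CrossEntropyLoss returns x unchanged. The loss for target 1 is therefore:

$$
\ell = -\log \operatorname{softmax}(x)_1
     = -\log \frac{e^{\log p_1}}{\sum_j e^{\log p_j}}
     = -\log \frac{p_1}{\sum_j p_j}
     = -\log 0.2 \approx 1.6094
$$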
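NLLLoss is the core of cross entropy in PyTorch: nn.CrossEntropyLoss applies LogSoftmax to its input and then NLLLoss, which is why the ignore_index behaviour of NLLLoss shows up when calling CrossEntropyLoss.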
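A short sketch, assuming a recent PyTorch release and random logits chosen only for illustration, that checks two things: CrossEntropyLoss gives the same value as NLLLoss applied to log_softmax of the logits, and in a mixed batch the samples labelled -100 (the default ignore_index) are dropped from both the sum and the count used for the mean. It also checks that a label such as -1 is not ignored but rejected as out of range.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)                 # 4 samples, 5 classes (random, illustration only)
target = torch.tensor([1, -100, 3, -100])  # two samples carry the default ignore_index

# 1) CrossEntropyLoss equals NLLLoss applied to the log-softmax of the logits
ce = nn.CrossEntropyLoss()(logits, target)
log_probs = F.log_softmax(logits, dim=1)
nll = nn.NLLLoss()(log_probs, target)

# 2) reduction='mean' averages only over the non-ignored samples
keep = target != -100
kept_log_probs = log_probs[keep]           # rows of the non-ignored samples
kept_targets = target[keep]
manual = -kept_log_probs[torch.arange(len(kept_targets)), kept_targets].mean()

# 3) A negative label other than ignore_index is out of range and raises
try:
    F.nll_loss(log_probs, torch.tensor([1, -1, 3, 0]))
    bad_label_raised = False
except (IndexError, RuntimeError):
    bad_label_raised = True

print(ce.item(), nll.item(), manual.item(), bad_label_raised)
```

Mixing ignored and real labels like this is the usual way ignore_index is used, for example to mask padding positions, and it avoids the all-ignored case whose result differs between PyTorch versions.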


Article link: https://www.haomeiwen.com/subject/bljrqctx.html