美文网首页
pytorch : CrossEntropyLoss 应用于语

pytorch : CrossEntropyLoss 应用于语

作者: 月牙眼的楼下小黑 | 来源:发表于2018-11-10 21:24 被阅读496次

    作 者: 月牙眼的楼下小黑
    联 系: zhanglf_tmac (Wechat)
    声 明: 欢迎转载本文中的图片或文字,请说明出处


    pytorch 官方文档中对 CrossEntropyLoss()的介绍,会产生一种错觉: pytorch中的CrossEntropyLoss似乎无法应用于多类别的图像语义分割任务。

    其实: pytorch中的CrossEntropyLoss 是可以直接应用于语义分割任务的。

    我们不妨假设一个分割网络的输出形状为: (channel = 3, width = 2, height = 2) ,即 2 x 2 分辨率的图像,其中每个像素可能属于 {0,1,2} 三类中的其中一类。

    import torch
    from torch import nn
    from torch.autograd import Variable
    
    input = Variable(torch.ones(1,3,2,2), requires_grad=True)
    target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
    print('input:', input)
    print('target:', target)
    
    loss = nn.CrossEntropyLoss()
    print('loss: ', loss(input, target))
    
    input: Variable containing:
    (0 ,0 ,.,.) = 
      1  1
      1  1
    
    (0 ,1 ,.,.) = 
      1  1
      1  1
    
    (0 ,2 ,.,.) = 
      1  1
      1  1
    [torch.FloatTensor of size 1x3x2x2]
    
    target: Variable containing:
    (0 ,.,.) = 
      0  1
      1  0
    [torch.LongTensor of size 1x2x2]
    
    loss:  Variable containing:
     1.0986
    [torch.FloatTensor of size 1]
    

    我们讨论一下两个细节:

    问题1: 输出的 loss 形状为什么是 1x 1

    默认情况下,即 size_average = True, loss 会在每个 mini-batch(小批量) 上取平均值. 如果字段 size_average 被设置为 False, loss 将会在每个 mini-batch(小批量) 上累加, 而不会取平均值.

    那么这个 mini_batch_size 等于几呢? 在程序中,网络输出形状为 4-d Tensor: ( batch_size, channel, width, height)。 注意: mini_batch_size != batch_size, 而是: mini_batch_size = batch_size * width * height.

    这非常好理解,因为语义分割本质上是 pixel-level classification, 所以 mini_batch_size 就等于一个 batch 图像中的 像素总数

    我们可以将上面代码中 loss 参数 size_average 设为 False , 做个简单的验证:

    import torch
    from torch import nn
    
    input = Variable(torch.ones(1,3,2,2), requires_grad=True)
    target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
    print('input:', input)
    print('target:', target)
    
    loss = nn.CrossEntropyLoss(size_average=False)
    print('loss', loss(input, target))
    

    此时输出的 loss 值为: 4.3944, 正好是 1.09861 x 2 x 2 倍。

    问题2:如何得到每个 pixel 的 loss ?

    只需将loss 参数 reduce 设为 False 即可。若网络输出形状为 4-d Tensor: ( batch_size, channel, width, height), 此时 loss 函数会返回一个 3-d Tensor:batch_size, width, height), 每个元素对应一个 pixelloss 值。

    import torch
    from torch import nn
    
    input = Variable(torch.ones(1,3,2,2), requires_grad=True)
    target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
    
    loss = nn.CrossEntropyLoss(reduce=False)
    print('loss: ', loss(input, target))
    
    loss: Variable containing:
    (0 ,.,.) = 
      1.0986  1.0986
      1.0986  1.0986
    [torch.FloatTensor of size 1x2x2]
    

    相关文章

      网友评论

          本文标题:pytorch : CrossEntropyLoss 应用于语

          本文链接:https://www.haomeiwen.com/subject/iertfqtx.html