作 者: 月牙眼的楼下小黑
联 系: zhanglf_tmac (Wechat)
声 明: 欢迎转载本文中的图片或文字,请说明出处
看 pytorch
官方文档中对 CrossEntropyLoss()
的介绍,会产生一种错觉: pytorch
中的CrossEntropyLoss
似乎无法应用于多类别的图像语义分割任务。
其实: pytorch
中的CrossEntropyLoss
是可以直接应用于语义分割任务的。
我们不妨假设一个分割网络的输出形状为: (channel = 3, width = 2, height = 2) ,即 2 x 2
分辨率的图像,其中每个像素可能属于 {0,1,2} 三类中的其中一类。
import torch
from torch import nn
from torch.autograd import Variable
input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
print('input:', input)
print('target:', target)
loss = nn.CrossEntropyLoss()
print('loss: ', loss(input, target))
input: Variable containing:
(0 ,0 ,.,.) =
1 1
1 1
(0 ,1 ,.,.) =
1 1
1 1
(0 ,2 ,.,.) =
1 1
1 1
[torch.FloatTensor of size 1x3x2x2]
target: Variable containing:
(0 ,.,.) =
0 1
1 0
[torch.LongTensor of size 1x2x2]
loss: Variable containing:
1.0986
[torch.FloatTensor of size 1]
我们讨论一下两个细节:
问题1: 输出的 loss 形状为什么是 1x 1 ?
默认情况下,即 size_average = True
, loss
会在每个 mini-batch
(小批量) 上取平均值. 如果字段 size_average
被设置为 False
, loss
将会在每个 mini-batch
(小批量) 上累加, 而不会取平均值.
那么这个 mini_batch_size
等于几呢? 在程序中,网络输出形状为 4-d Tensor
: ( batch_size
, channel
, width
, height
)。 注意: mini_batch_size != batch_size, 而是: mini_batch_size = batch_size * width * height.
这非常好理解,因为语义分割本质上是 pixel-level classification
, 所以 mini_batch_size
就等于一个 batch
图像中的 像素总数。
我们可以将上面代码中 loss
参数 size_average
设为 False
, 做个简单的验证:
import torch
from torch import nn
input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
print('input:', input)
print('target:', target)
loss = nn.CrossEntropyLoss(size_average=False)
print('loss', loss(input, target))
此时输出的 loss
值为: 4.3944
, 正好是 1.0986
的 1 x 2 x 2
倍。
问题2:如何得到每个 pixel 的 loss ?
只需将loss
参数 reduce
设为 False
即可。若网络输出形状为 4-d Tensor
: ( batch_size
, channel
, width
, height
), 此时 loss
函数会返回一个 3-d Tensor:
(batch_size
, width
, height
), 每个元素对应一个 pixel
的 loss
值。
import torch
from torch import nn
input = Variable(torch.ones(1,3,2,2), requires_grad=True)
target = Variable(torch.LongTensor([[[0,1],[1,0]]]))
loss = nn.CrossEntropyLoss(reduce=False)
print('loss: ', loss(input, target))
loss: Variable containing:
(0 ,.,.) =
1.0986 1.0986
1.0986 1.0986
[torch.FloatTensor of size 1x2x2]
网友评论