报错内容:
UserWarning: Using a target size (torch.Size([1, 224, 224])) that is different to the input size (torch.Size([1, 1, 224, 224])) is deprecated. Please ensure they have the same size.
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
报错原因:是因为在相关的函数中两个矩阵的维度不一样所导致的。
解决办法:
使用torch.unsqueeze()
或者torch.squeeze()
进行升降维。
例如:
net = UNet().to(device)
net.train()
loss_fn = nn.BCELoss()
for i,(img,target) in enumerate(train_loader):
img, target = img.to(device), target.to(device)
y = net(img)
loss = loss_fn(y, target.unsqueeze(dim=0))
网友评论