A senior labmate asked about NaNs that show up during backward() and how hard such errors are to locate, so it seems worth writing this down.
Solution
Automatic differentiation package - torch.autograd — PyTorch 1.9.0 documentation
Wrap the forward/backward pass in with torch.autograd.detect_anomaly(): to catch the error:
import torch
from torch import autograd

class MyFunc(autograd.Function):  # a function that produces NaN in backward
    @staticmethod
    def forward(ctx, inp):
        return inp.clone()

    @staticmethod
    def backward(ctx, gO):
        grad1 = torch.zeros_like(gO) / torch.zeros_like(gO)  # 0 / 0 -> NaN
        return grad1

class Net(torch.nn.Module):  # toy net
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        o1 = self.l1(x)
        o2 = MyFunc.apply(o1)
        o3 = self.l2(o2)
        return o3.sum()

with autograd.detect_anomaly():  # enable anomaly detection
    inp = torch.rand(10, 10, requires_grad=True)
    m = Net()
    out = m(inp)
    out.backward()
    print(inp.grad)
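With anomaly detection enabled, backward() no longer fails silently: it raises a RuntimeError naming the backward function that produced the NaN and prints the traceback of the forward call that created that node, so the offending op (here MyFunc) is easy to find.

Besides the context manager, the same check can be switched on globally with torch.autograd.set_detect_anomaly(True), which is convenient when the training loop lives deep inside framework code. A minimal sketch reusing the toy Net above (anomaly detection slows execution, so only keep it on while debugging):

import torch

# Enable anomaly detection for everything that follows,
# without wrapping each forward/backward in a context manager.
torch.autograd.set_detect_anomaly(True)

inp = torch.rand(10, 10, requires_grad=True)
m = Net()          # the toy net defined above
out = m(inp)
out.backward()     # raises a RuntimeError pointing at MyFunc's backward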
The brute-force approach: directly inspect whether the gradient at each node of the model was computed correctly:
for name, param in m.named_parameters():
    shape, c = (param.grad.shape, param.grad.sum()) if param.grad is not None else (None, None)
    print(f'{name}: {param.shape} \n\t grad: {shape} \n\t {c}')
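This check can also be automated: a tensor hook can flag non-finite gradients the moment they are produced during backward, instead of inspecting every parameter by hand afterwards. A minimal sketch, assuming anomaly detection is off and reusing the toy Net above; the helper name register_nan_hooks is my own:

import torch

def register_nan_hooks(model: torch.nn.Module):
    # Hypothetical helper: attach a hook to every parameter that
    # reports non-finite (NaN/Inf) gradients as they arrive.
    for name, param in model.named_parameters():
        def hook(grad, name=name):  # default arg binds the current name
            if not torch.isfinite(grad).all():
                print(f'non-finite gradient in {name}')
            return grad
        param.register_hook(hook)

register_nan_hooks(m)  # m is the toy Net from above
out = m(torch.rand(10, 10))
out.backward()         # prints l1.weight / l1.bias, fed by MyFunc's NaN grad

Because the NaN originates in MyFunc's backward, only the parameters upstream of it (l1) receive non-finite gradients, which narrows down where in the graph the problem starts.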