美文网首页
Torch反向传播时出错或者梯度为NaN的问题排查

Torch反向传播时出错或者梯度为NaN的问题排查

作者: 酌泠 | 来源:发表于2021-08-23 20:56 被阅读0次

师兄问到的backward时出现NaN和错误不好定位的问题,感觉有必要记录一下

解决方案
Automatic differentiation package - torch.autograd — PyTorch 1.9.0 documentation
使用with torch.autograd.detect_anomaly():包裹传播过程检测错误

import torch
from torch import autograd

class MyFunc(autograd.Function):  # A func generate NaN when backward
    @staticmethod
    def forward(ctx, inp):
        return inp.clone()

    @staticmethod
    def backward(ctx, gO):
        grad1 = torch.zeros_like(gO) / torch.zeros_like(gO)  # NaN
        return grad1


class Net(torch.nn.Module):  # toy net
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        o1 = self.l1(x)
        o2 = MyFunc.apply(o1)
        o3 = self.l2(o2)
        return o3.sum()

with autograd.detect_anomaly():  # enable Anomaly Detection
    inp = torch.rand(10, 10, requires_grad=True)
    m = Net()
    out = m(inp)
    out.backward()
    print(inp.grad)

粗暴的方法,直接查看模型各节点的梯度是否计算正确

for name, param in m.named_parameters():
    shape, c = (param.grad.shape, param.grad.sum()) if param.grad is not None else (None, None)
    print(f'{name}: {param.shape} \n\t grad: {shape} \n\t {c}')

相关文章

  • Torch反向传播时出错或者梯度为NaN的问题排查

    师兄问到的backward时出现NaN和错误不好定位的问题,感觉有必要记录一下 解决方案Automatic dif...

  • PyTorch:动态计算态图和静态计算图

    Torch中反向传播时候其梯度时根据链式求导法则构建的动态计算图所谓的动态计算图就是在搭建网络的时候自动构建的计算...

  • 机器学习速成课程 学习笔记23:训练神经网络

    失败案例 很多常见情况都会导致反向传播算法出错。 梯度消失 较低层(更接近输入)的梯度可能会变得非常小。在深度网络...

  • 激活函数

    1、非线性激活函数 sigmoid、tanh 问题:1、计算量大;2、容易有梯度消失问题 梯度消失问题:在反向传播...

  • 机器学习分享——反向传播算法推导

    反向传播(英语:Backpropagation,缩写为BP)是“误差反向传播”的简称,是一种与最优化方法(如梯度下...

  • 反向传播

    反向传播(英語:Backpropagation,缩写为BP)是“误差反向传播”的简称,是一种与最优化方法(如梯度下...

  • 笔记6-Deep learning and backpropag

    优化算法:梯度下降,反向传播(BP)是梯度下降实现方法之一。

  • 深度学习 | 梯度消散/爆炸

    1.梯度消散/爆炸 原因 梯度消散:反向传播是逐层对函数偏导相乘,因此网络很深时,最后的偏差会越来越小直到为0。 ...

  • TensorFlow2 自动微分机制

    神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情。而深度学习框架可以帮助...

  • Batch Normalization

    一、不使用Batch Normalization 对某层的前向传播过程有:针对该层的反向传播过程为:连续多层的梯度...

网友评论

      本文标题:Torch反向传播时出错或者梯度为NaN的问题排查

      本文链接:https://www.haomeiwen.com/subject/hfnliltx.html