美文网首页
检查PyTorch图的梯度流

检查PyTorch图的梯度流

作者: 药柴 | 来源:发表于2018-09-25 11:04 被阅读0次

    项目中用到了自定义的损失函数,但是在训练过程中发现损失保持不变,说明可能梯度的传导存在问题。在PyTorch论坛中的How to check for vanishing/exploding gradients发现了一个由Adam Paszke给出的较好的小程序bad_grad_viz.py,特别摘录如下:

    from graphviz import Digraph
    import torch
    from torch.autograd import Variable, Function
    
    def iter_graph(root, callback):
        queue = [root]
        seen = set()
        while queue:
            fn = queue.pop()
            if fn in seen:
                continue
            seen.add(fn)
            for next_fn, _ in fn.next_functions:
                if next_fn is not None:
                    queue.append(next_fn)
            callback(fn)
    
    def register_hooks(var):
        fn_dict = {}
        def hook_cb(fn):
            def register_grad(grad_input, grad_output):
                fn_dict[fn] = grad_input
            fn.register_hook(register_grad)
        iter_graph(var.grad_fn, hook_cb)
    
        def is_bad_grad(grad_output):
            grad_output = grad_output.data
            return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()
    
        def make_dot():
            node_attr = dict(style='filled',
                            shape='box',
                            align='left',
                            fontsize='12',
                            ranksep='0.1',
                            height='0.2')
            dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    
            def size_to_str(size):
                return '('+(', ').join(map(str, size))+')'
    
            def build_graph(fn):
                if hasattr(fn, 'variable'):
                    u = fn.variable
                    node_name = 'Variable\n ' + size_to_str(u.size())
                    dot.node(str(id(u)), node_name, fillcolor='lightblue')
                else:
                    assert fn in fn_dict, fn
                    fillcolor = 'white'
                    if any(is_bad_grad(gi) for gi in fn_dict[fn]):
                        fillcolor = 'red'
                    dot.node(str(id(fn)), str(type(fn).__name__), fillcolor=fillcolor)
                for next_fn, _ in fn.next_functions:
                    if next_fn is not None:
                        next_id = id(getattr(next_fn, 'variable', next_fn))
                        dot.edge(str(next_id), str(id(fn)))
            iter_graph(var.grad_fn, build_graph)
    
            return dot
    
        return make_dot
    
    if __name__ == '__main__':
        x = Variable(torch.randn(10, 10), requires_grad=True)
        y = Variable(torch.randn(10, 10), requires_grad=True)
    
        z = x / (y * 0)
        z = z.sum() * 2
        get_dot = register_hooks(z)
        z.backward()
        dot = get_dot()
        dot.save('tmp.dot')
    

    例程运行得到一个tmp.dot文件,可视化效果如下:

    tmp.png
    可以看到由于计算式中出现了x / (y * 0),梯度出现了问题,这两个function被标为红色。将x / (y * 0)改为x / (y * 1)后,生成的图就变成了
    tmp.png
    在本人的例子中,出现了‘NoneType' object has no attribute 'data'的问题。
    这里需要注意的是,这段代码假设了图的所有输入都是设置了requires_grad=True的,然而很多时候这种情况并不满足,例如在简单的图像分类问题中,我们对于输入图像并不要求梯度,因为我们不需要对其进行修改。因此,为了使这段代码能够直接适用于大型的模型,可以修改register_hooks(var)内的is_bad_grad(grad_output),以修正这一个错误,如下:
    def is_bad_grad(grad_output):
            try:
                grad_output = grad_output.data
            except:
                print('Fail to get grad')
                return True
            return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()
    

    不过,这个修改只是简单为了让这段代码能够工作,实际上时不符合道理的。事实上,这段代码更适合小的单元模块测试,例如自定义的Loss函数。

    相关文章

      网友评论

          本文标题:检查PyTorch图的梯度流

          本文链接:https://www.haomeiwen.com/subject/qybkoftx.html