美文网首页
pytorch网络结构可视化

pytorch网络结构可视化

作者: 顾子豪 | 来源:发表于2020-10-12 17:13 被阅读0次
    • 1. 首先安装 python-graphviz 包(下面命令中的 pytorch 为 conda 环境名,请替换为你自己的环境):
    conda install -n pytorch python-graphviz
    
    • 2. 将以下代码保存到自己的项目目录中,文件名为 visualize.py:
    from graphviz import Digraph
    import torch
    from torch.autograd import Variable
    
    
    def make_dot(var, params=None):
        """Produce a Graphviz representation of a PyTorch autograd graph.

        The graph is built by walking backwards from ``var.grad_fn`` through
        each node's ``next_functions`` and ``saved_tensors``. Nodes are drawn
        as follows:

        * light blue -- nodes exposing a ``variable`` attribute (leaf tensors
          that require grad), labelled with their name from ``params`` (if
          any) and their size;
        * orange -- tensors reached through a node's ``saved_tensors``,
          labelled with their size;
        * default fill -- every other autograd node, labelled with its
          type name.

        Args:
            var: output tensor whose autograd history is drawn.
            params: optional mapping of name -> tensor (for example
                ``dict(model.named_parameters())``) used to label the blue
                leaf nodes. Leaves that are not in the mapping get an empty
                name. An empty mapping is allowed.

        Returns:
            graphviz.Digraph: the graph. If ``var.grad_fn`` is None (``var``
            has no autograd history) the graph is empty.

        Raises:
            TypeError: if any value in ``params`` is not a tensor.
        """
        if params is not None:
            # dict.values() is a view in Python 3 and cannot be indexed, so
            # validate every value rather than subscripting the first one.
            for key, value in params.items():
                if not isinstance(value, torch.Tensor):
                    raise TypeError('params[%r] must be a tensor, got %s'
                                    % (key, type(value).__name__))
            param_map = {id(v): k for k, v in params.items()}
        else:
            param_map = {}

        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

        def size_to_str(size):
            """Format a torch.Size as '(d0, d1, ...)'."""
            return '(' + ', '.join('%d' % v for v in size) + ')'

        root = var.grad_fn
        if root is None:
            # Nothing to draw: var is a leaf or was computed without grad.
            return dot

        seen = set()
        # Explicit stack instead of recursion: autograd graphs of deep
        # networks can be deeper than Python's default recursion limit.
        stack = [root]
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            node_id = str(id(node))

            if torch.is_tensor(node):
                dot.node(node_id, size_to_str(node.size()), fillcolor='orange')
            elif hasattr(node, 'variable'):
                leaf = node.variable
                # .get(): leaves not listed in params are drawn unnamed
                # instead of raising KeyError.
                name = param_map.get(id(leaf), '')
                label = '%s\n %s' % (name, size_to_str(leaf.size()))
                dot.node(node_id, label, fillcolor='lightblue')
            else:
                dot.node(node_id, str(type(node).__name__))

            # Edges point from inputs towards the node that consumes them.
            for child, _ in getattr(node, 'next_functions', ()):
                if child is not None:
                    dot.edge(str(id(child)), node_id)
                    stack.append(child)
            for tensor in getattr(node, 'saved_tensors', ()):
                dot.edge(str(id(tensor)), node_id)
                stack.append(tensor)
        return dot
    
    
    • 3. 使用方法(示例网络 simpleconv3,其中 visualize 即上一步保存的文件):
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    
    class simpleconv3(nn.Module):
        """Small CNN classifier used as the visualization example.

        Three stages of 3x3 stride-2 convolution -> BatchNorm -> ReLU
        (3 -> 12 -> 24 -> 48 channels), followed by fully connected layers
        48*5*5 -> 1200 -> 128 -> num_classes.

        ``fc1`` expects the last conv stage to produce a 48 x 5 x 5 feature
        map, which a 48 x 48 input yields (48 -> 23 -> 11 -> 5 with kernel 3,
        stride 2 and no padding).

        Args:
            num_classes: number of output logits (default 2).
        """

        def __init__(self, num_classes=2):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 12, 3, 2)
            self.bn1 = nn.BatchNorm2d(12)
            self.conv2 = nn.Conv2d(12, 24, 3, 2)
            self.bn2 = nn.BatchNorm2d(24)
            self.conv3 = nn.Conv2d(24, 48, 3, 2)
            self.bn3 = nn.BatchNorm2d(48)
            self.fc1 = nn.Linear(48 * 5 * 5, 1200)
            self.fc2 = nn.Linear(1200, 128)
            self.fc3 = nn.Linear(128, num_classes)

        def forward(self, x):
            """Return logits of shape (batch, num_classes).

            Args:
                x: input of shape (batch, 3, H, W). H and W must make the
                    last conv stage output 5 x 5 (e.g. 48 x 48); otherwise
                    ``fc1`` raises a shape-mismatch RuntimeError.
            """
            x = F.relu(self.bn1(self.conv1(x)))
            x = F.relu(self.bn2(self.conv2(x)))
            x = F.relu(self.bn3(self.conv3(x)))
            # Flatten per sample, keeping the batch dimension explicit.
            # view(-1, 48*5*5) would silently regroup features across samples
            # for a wrongly sized input instead of failing.
            x = x.flatten(1)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    if __name__ == '__main__':
        # Demo: run one forward pass of simpleconv3 on a random 48 x 48 RGB
        # input, print the logits, and render the autograd graph of the output
        # with graphviz (view=True opens the rendered file).
        from visualize import make_dot

        # Plain tensors carry autograd history since PyTorch 0.4, so the
        # deprecated Variable wrapper is unnecessary; the model's parameters
        # already require grad, which gives y a grad_fn to draw.
        x = torch.randn(1, 3, 48, 48)
        model = simpleconv3()
        y = model(x)
        print(y.detach())
        g = make_dot(y)
        g.render('simpleconv3Visualize', view=True)
    

    运行结果(生成的网络结构图):


    simpleconv3Visualize.pdf

    相关文章

      网友评论

          本文标题:pytorch网络结构可视化

          本文链接:https://www.haomeiwen.com/subject/tdkppktx.html