Configuring network structure visualization tools for PyTorch

Author: gpfworld | Published 2019-06-28 11:19

Visualization tool: Graphviz

I. Installation
Graphviz http://www.graphviz.org/
Mac users are advised to install it with Homebrew, because the version on the official website is fairly old.

1. Install Homebrew
Open a terminal, then copy and paste the following command:
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"

2. Install Graphviz
Once Homebrew is installed, run brew install graphviz.
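The official-site build and the Homebrew build can both end up on one machine. A quick check of which `dot` the shell actually finds, and which version it reports, helps here. The following is a minimal sketch. It assumes only that `dot` is on PATH (Homebrew links it into /usr/local/bin on Intel Macs or /opt/homebrew/bin on Apple Silicon) and that `dot -V` prints a banner of the form "dot - graphviz version X.Y.Z (...)". The helper names `dot_version` and `parse_dot_version` are hypothetical.

```python
import re
import shutil
import subprocess


def parse_dot_version(banner):
    """Extract (major, minor[, patch]) from a `dot -V` banner, or None."""
    match = re.search(r"version\s+(\d+)\.(\d+)(?:\.(\d+))?", banner)
    if match is None:
        return None
    return tuple(int(g) for g in match.groups() if g is not None)


def dot_version():
    """Return the version tuple of the `dot` found on PATH, or None if absent."""
    dot = shutil.which("dot")
    if dot is None:
        return None
    # `dot -V` writes its banner to stderr, e.g.
    # "dot - graphviz version 2.40.1 (20161225.0304)"
    proc = subprocess.run([dot, "-V"], capture_output=True, text=True)
    return parse_dot_version(proc.stderr or proc.stdout)


if __name__ == "__main__":
    print(shutil.which("dot"), dot_version())
```

If the printed path is not under the Homebrew prefix, the older installation comes first on PATH.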

Example code after installation (it also needs the graphviz Python package from step 3 below):

import os

import torch
import torchvision.models as models
from graphviz import Digraph


def make_dot(var, params=None):
    """Build a Graphviz Digraph of the autograd graph that produced `var`.

    var:    an output tensor (or a tuple of output tensors) that has a grad_fn
    params: optional dict {name: parameter}, e.g. dict(model.named_parameters()),
            used to label the weight/bias leaf nodes by name
    """
    if params is not None:
        assert all(isinstance(p, torch.Tensor) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled', shape='box', align='left',
                     fontsize='12', ranksep='0.1', height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        # torch.Size([1, 1000]) -> "(1, 1000)"
        return '(' + ', '.join('%d' % v for v in size) + ')'

    output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                # note: this used to show .saved_tensors in pytorch0.2, but stopped
                # working as it was moved to ATen and Variable-Tensor merged
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                # AccumulateGrad node: wraps a leaf tensor such as a weight or bias
                u = var.variable
                name = param_map.get(id(u), '') if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            elif var in output_nodes:
                # grad_fn of the output tensor(s)
                dot.node(str(id(var)), str(type(var).__name__), fillcolor='darkolivegreen1')
            else:
                # intermediate backward op, e.g. ReluBackward0
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            # next_functions holds (node, index) pairs pointing toward the inputs
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)

    if isinstance(var, tuple):
        for v in var:
            add_nodes(v.grad_fn)
    else:
        add_nodes(var.grad_fn)
    return dot


if __name__ == "__main__":
    # Graphviz's `dot` executable must be on PATH. Homebrew installs it into
    # /usr/local/bin (Intel Macs) or /opt/homebrew/bin (Apple Silicon).
    os.environ["PATH"] += os.pathsep + '/usr/local/bin' + os.pathsep + '/opt/homebrew/bin'

    ## visualize the model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torchvision provides vgg11/vgg13/vgg16/vgg19; vgg16 is the 16-layer variant
    model = models.vgg16().to(device)

    # The input must live on the same device as the model.
    x = torch.randn(1, 3, 224, 224, device=device)
    vis_graph = make_dot(model(x), params=dict(model.named_parameters()))
    vis_graph.view()  # renders the graph with `dot` and opens the result
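At its core, make_dot is a depth-first walk that starts at the output's grad_fn and follows each node's next_functions back toward the inputs. It draws an edge from every input node to the node that consumes it and skips nodes it has already seen. The same traversal can be shown without torch or graphviz in a stdlib-only sketch. `FakeNode` is a hypothetical stand-in for a grad_fn object, and `to_dot_source` is a hypothetical helper that emits DOT text directly. The sketch also assumes node names are unique, because they double as DOT ids here.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(eq=False)  # compare nodes by identity, like real grad_fn objects
class FakeNode:
    """Hypothetical stand-in for a grad_fn: a name plus (node, index) pairs."""
    name: str
    next_functions: Tuple[Tuple[Optional["FakeNode"], int], ...] = ()


def to_dot_source(root: FakeNode) -> str:
    """Walk from the output node toward the inputs and emit DOT text.

    Edges point from input to consumer, as in make_dot.
    """
    lines = ["digraph {"]
    seen = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if id(node) in seen:  # a shared sub-graph is drawn only once
            continue
        seen.add(id(node))
        lines.append(f'  "{node.name}";')
        for child, _ in node.next_functions:
            if child is not None:  # None marks an input that needs no gradient
                lines.append(f'  "{child.name}" -> "{node.name}";')
                stack.append(child)
    lines.append("}")
    return "\n".join(lines)


if __name__ == "__main__":
    weight = FakeNode("AccumulateGrad")
    mm = FakeNode("MmBackward", ((weight, 0), (None, 0)))
    relu = FakeNode("ReluBackward", ((mm, 0),))
    print(to_dot_source(relu))
```

This sketch uses an explicit stack instead of recursion. make_dot recurses once per level of graph depth, so a very deep network may run into Python's default recursion limit.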

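3. Install the corresponding Python package: sudo pip install graphviz
The graphviz package installed with pip is only a Python interface. Graphviz itself is not a Python tool, so you still need to install Graphviz's executables (steps 1 and 2 above). After looking into it, I found that this was my problem: I had not installed Graphviz's executables.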
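The mix-up described above comes from having two separate pieces. One is the pip package, which builds DOT source in Python. The other is the Graphviz executables, which do the actual rendering. A minimal check that reports each piece on its own might look like this. It is only a sketch that assumes the executable is named `dot` and must be findable on PATH; `graphviz_setup_status` is a hypothetical helper name.

```python
import importlib.util
import shutil


def graphviz_setup_status():
    """Report the two separate pieces the graphviz workflow needs."""
    return {
        # The pip package: a Python wrapper that writes DOT source.
        "python_package": importlib.util.find_spec("graphviz") is not None,
        # The Graphviz executables (e.g. from `brew install graphviz`);
        # rendering the graph requires `dot` to be findable on PATH.
        "dot_executable": shutil.which("dot") is not None,
    }


if __name__ == "__main__":
    for part, ok in graphviz_setup_status().items():
        print(f"{part}: {'found' if ok else 'MISSING'}")
```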

Tool for displaying the network structure: torchsummary

Install it with sudo pip install torchsummary.
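torchsummary prints a per-layer table with output shapes and a "Param #" column. For VGG-16 (models.vgg16() in torchvision), the call would be something like summary(model, input_size=(3, 224, 224), device="cpu"), assuming the package's summary(model, input_size, batch_size=-1, device="cuda") signature. Its device argument defaults to "cuda", so a CPU-only machine needs device="cpu". The "Param #" numbers follow simple arithmetic. The sketch below reproduces it for the standard VGG-16 layout without importing torch. The layer list and helper names are assumptions made for illustration.

```python
# Hypothetical sketch: the "Param #" arithmetic for standard VGG-16
# (13 3x3 convolutions with bias, then three fully connected layers).
VGG16_CONVS = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
               512, 512, 512, "M", 512, 512, 512, "M"]


def conv2d_params(in_ch, out_ch, k=3, bias=True):
    """out_ch * in_ch * k * k weights, plus one bias per output channel."""
    return out_ch * in_ch * k * k + (out_ch if bias else 0)


def linear_params(in_f, out_f, bias=True):
    """out_f * in_f weights, plus one bias per output feature."""
    return out_f * in_f + (out_f if bias else 0)


def vgg16_param_count(num_classes=1000):
    total, in_ch = 0, 3
    for v in VGG16_CONVS:
        if v == "M":  # max-pooling layers have no parameters
            continue
        total += conv2d_params(in_ch, v)
        in_ch = v
    # Five 2x2 poolings shrink a 224x224 input to 7x7, so the first
    # fully connected layer sees 512 * 7 * 7 = 25088 features.
    for in_f, out_f in [(512 * 7 * 7, 4096), (4096, 4096), (4096, num_classes)]:
        total += linear_params(in_f, out_f)
    return total


if __name__ == "__main__":
    print(conv2d_params(3, 64))  # 1792
    print(vgg16_param_count())   # 138357544
```

The first fully connected layer alone accounts for 102,764,544 of the roughly 138 million parameters. This is why it dominates the "Param #" column.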

Article link: https://www.haomeiwen.com/subject/qjkscctx.html