torchviz:https://github.com/szagoruyko/pytorchviz
转载于:使用pytorchviz进行pytorch执行过程的可视化 - pytorch中文网
1. 安装
pip install graphviz
pip install git+https://github.com/szagoruyko/pytorchviz
import torch
from torch.autograd import Variable
from torch import nn
from torchviz import make_dot, make_dot_from_trace
model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))
x = Variable(torch.randn(1,8))
y = model(x)
make_dot(y.mean(), params=dict(model.named_parameters()))

- 主要有两个函数,
make_dot
可以从任何PyTorch
函数(要求至少有一个输入变量requires_grad
)中生成图形,并make_dot_from_trace
使用输出torch.jit.trace
(并不总是有效)。参见examples.ipynb。
网友评论