make_dot可以打印神经网络结构,并存储。
1. 库
from torchviz import make_dot
2.实现
net_struct = make_dot(net_out)
net_struct.render("net_struct", view=False)
由于在linux系统使用该函数,没有可视化,因此将view设置为False即可将网络结构存储为pdf格式。
3.示例
import tensorwatch as tw
import torchvision.models
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 1 input channel, 6 output channel, 5 conv
self.conv1 = nn.Conv2d(1, 6, 5, bias=False)
self.conv2 = nn.Conv2d(6, 10, 5, bias=False)
self.bn1 = nn.BatchNorm2d(10, eps=1e-05, momentum=0.1)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.bn1(x)
x = self.relu(x)
return x
def num_flat_features(self, x):
size = x.size()[1:]
num_features = 1
for s in size:
num_features *=s
return num_features
net = Net()
print(net)
x = torch.zeros(1, 1, 20, 20, dtype=torch.float, requires_grad=False)
net_out = net(x)
net_struct = make_dot(net_out) # plot graph of variable, not of a nn.Module
net_struct.render("net_struct", view=False)
打印结果如图所示,可以看到网络结构中对应的conv,bn,relu层及相关参数size
网友评论