本文解决问题:
torchsummary
针对多个输入模型的时候,其输出信息中input size
等存在着错误,这里提供方案解决这个错误。
当我们使用pytorch
搭建好我们自己的深度学习模型的的时候,我们总想看看具体的网络信息以及参数量大小,这时候就要请出我们的神器 torchsummary
了,torchsummary
的简单使用如下所示:
# pip install torchsummary
from torchsummary import summary
model = OurOwnModel()
summary(model, input_size=(3, 224, 224), device='cpu')
此时一切正常的话将会输出下面的信息:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 224, 224] 1,792
ReLU-2 [-1, 64, 224, 224] 0
Conv2d-3 [-1, 64, 224, 224] 36,928
ReLU-4 [-1, 64, 224, 224] 0
MaxPool2d-5 [-1, 64, 112, 112] 0
Conv2d-6 [-1, 128, 112, 112] 73,856
ReLU-7 [-1, 128, 112, 112] 0
Conv2d-8 [-1, 128, 112, 112] 147,584
ReLU-9 [-1, 128, 112, 112] 0
MaxPool2d-10 [-1, 128, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 295,168
ReLU-12 [-1, 256, 56, 56] 0
Conv2d-13 [-1, 256, 56, 56] 590,080
ReLU-14 [-1, 256, 56, 56] 0
Conv2d-15 [-1, 256, 56, 56] 590,080
ReLU-16 [-1, 256, 56, 56] 0
MaxPool2d-17 [-1, 256, 28, 28] 0
Conv2d-18 [-1, 512, 28, 28] 1,180,160
ReLU-19 [-1, 512, 28, 28] 0
Conv2d-20 [-1, 512, 28, 28] 2,359,808
ReLU-21 [-1, 512, 28, 28] 0
Conv2d-22 [-1, 512, 28, 28] 2,359,808
ReLU-23 [-1, 512, 28, 28] 0
MaxPool2d-24 [-1, 512, 14, 14] 0
Conv2d-25 [-1, 512, 14, 14] 2,359,808
ReLU-26 [-1, 512, 14, 14] 0
Conv2d-27 [-1, 512, 14, 14] 2,359,808
ReLU-28 [-1, 512, 14, 14] 0
Conv2d-29 [-1, 512, 14, 14] 2,359,808
ReLU-30 [-1, 512, 14, 14] 0
MaxPool2d-31 [-1, 512, 7, 7] 0
Linear-32 [-1, 4096] 102,764,544
ReLU-33 [-1, 4096] 0
Dropout-34 [-1, 4096] 0
Linear-35 [-1, 4096] 16,781,312
ReLU-36 [-1, 4096] 0
Dropout-37 [-1, 4096] 0
Linear-38 [-1, 1000] 4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------
你发现一切安好,nice。但是当你像我一样开始搭建一个多输入网络的时候,这时候麻烦就来了。
from torchsummary import summary
model = OurOwnModel()
summary(model, input_size=[(3, 224, 224), (3, 224, 224), (3, 123)], device='cpu')
此时输出的信息就会有错误了。
# 上面正确的信息省略了
================================================================
Total params: 49,365,761
Trainable params: 49,365,761
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 25169045225472.00 # 输入的大小显然不对啊
Forward/backward pass size (MB): 22975.86
Params size (MB): 188.32
Estimated Total Size (MB): 25169045248636.18 # 看起来整个数据也是显然有错误的
----------------------------------------------------------------
上面的 Input Size(MB) Estimated Total Size (MB)
这两项显然是有错误的。
这里提供如下的解决办法:
import torchsummary
print(torchsummary.__file__)
上面代码会输出torchsummary
的安装路径,这里得到的如下:
/home/guangkun/anaconda3/envs/jet/lib/python3.7/site-packages/torchsummary/__init__.py
我们知道了torchsummary
的地址之后,进入该文件夹,同级目录如下:
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ └── torchsummary.cpython-37.pyc
└── torchsummary.py
修改 torchsummary.py
文件(大概在100行-103行):
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size
修改为:
total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size
保存后再运行即可发现正常了,正常的输出信息如下:
================================================================
Total params: 49,365,761
Trainable params: 49,365,761
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.64
Forward/backward pass size (MB): 179.50
Params size (MB): 188.32
Estimated Total Size (MB): 369.45
----------------------------------------------------------------
网友评论