Abnormal input size reported by torchsummary

Author: 爪爪熊 | Published: 2020-09-25 20:17

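    What this post solves

    When a model takes multiple inputs, torchsummary prints wrong values for Input size and the totals derived from it. This post provides a fix for that error.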


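    Once we have built our own deep learning model with PyTorch, we usually want to inspect the layers of the network and its parameter count. torchsummary is the handy tool for that. Basic usage looks like this: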

    # pip install torchsummary
    from torchsummary import summary
    
    model = OurOwnModel()  # placeholder for your own nn.Module
    summary(model, input_size=(3, 224, 224), device='cpu')
    

    If everything works, it prints the following:

    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1         [-1, 64, 224, 224]           1,792
                  ReLU-2         [-1, 64, 224, 224]               0
                Conv2d-3         [-1, 64, 224, 224]          36,928
                  ReLU-4         [-1, 64, 224, 224]               0
             MaxPool2d-5         [-1, 64, 112, 112]               0
                Conv2d-6        [-1, 128, 112, 112]          73,856
                  ReLU-7        [-1, 128, 112, 112]               0
                Conv2d-8        [-1, 128, 112, 112]         147,584
                  ReLU-9        [-1, 128, 112, 112]               0
            MaxPool2d-10          [-1, 128, 56, 56]               0
               Conv2d-11          [-1, 256, 56, 56]         295,168
                 ReLU-12          [-1, 256, 56, 56]               0
               Conv2d-13          [-1, 256, 56, 56]         590,080
                 ReLU-14          [-1, 256, 56, 56]               0
               Conv2d-15          [-1, 256, 56, 56]         590,080
                 ReLU-16          [-1, 256, 56, 56]               0
            MaxPool2d-17          [-1, 256, 28, 28]               0
               Conv2d-18          [-1, 512, 28, 28]       1,180,160
                 ReLU-19          [-1, 512, 28, 28]               0
               Conv2d-20          [-1, 512, 28, 28]       2,359,808
                 ReLU-21          [-1, 512, 28, 28]               0
               Conv2d-22          [-1, 512, 28, 28]       2,359,808
                 ReLU-23          [-1, 512, 28, 28]               0
            MaxPool2d-24          [-1, 512, 14, 14]               0
               Conv2d-25          [-1, 512, 14, 14]       2,359,808
                 ReLU-26          [-1, 512, 14, 14]               0
               Conv2d-27          [-1, 512, 14, 14]       2,359,808
                 ReLU-28          [-1, 512, 14, 14]               0
               Conv2d-29          [-1, 512, 14, 14]       2,359,808
                 ReLU-30          [-1, 512, 14, 14]               0
            MaxPool2d-31            [-1, 512, 7, 7]               0
               Linear-32                 [-1, 4096]     102,764,544
                 ReLU-33                 [-1, 4096]               0
              Dropout-34                 [-1, 4096]               0
               Linear-35                 [-1, 4096]      16,781,312
                 ReLU-36                 [-1, 4096]               0
              Dropout-37                 [-1, 4096]               0
               Linear-38                 [-1, 1000]       4,097,000
    ================================================================
    Total params: 138,357,544
    Trainable params: 138,357,544
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.57
    Forward/backward pass size (MB): 218.59
    Params size (MB): 527.79
    Estimated Total Size (MB): 746.96
    ----------------------------------------------------------------
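
As a sanity check, the size lines of this summary can be reproduced by hand. The sketch below assumes 4 bytes per value (float32), the same factor that appears in the torchsummary formula quoted later in this post. `to_mb` is a hypothetical helper name, not part of torchsummary.

```python
# Reproduce two of the size lines from the summary above by hand.
BYTES_PER_VALUE = 4   # assumption: every value is a float32
MB = 1024 ** 2


def to_mb(num_values):
    """Size in MB of `num_values` float32 values."""
    return num_values * BYTES_PER_VALUE / MB


input_mb = to_mb(3 * 224 * 224)   # a single (3, 224, 224) input
params_mb = to_mb(138_357_544)    # "Total params" from the table

print(f"Input size (MB): {input_mb:.2f}")    # Input size (MB): 0.57
print(f"Params size (MB): {params_mb:.2f}")  # Params size (MB): 527.79
```

Both values match the report, which makes this arithmetic a convenient yardstick for judging whether a later report is plausible.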
    

    Everything looks fine so far. But once you start building a multi-input network, as I did, the trouble begins.

    from torchsummary import summary
    
    model = OurOwnModel()  # placeholder for a model whose forward() takes three inputs
    summary(model, input_size=[(3, 224, 224), (3, 224, 224), (3, 123)], device='cpu')
    

    This time the printed summary contains errors.

    # (the per-layer part above, which is correct, is omitted)
    ================================================================
    Total params: 49,365,761
    Trainable params: 49,365,761
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 25169045225472.00  # the input size is obviously wrong
    Forward/backward pass size (MB): 22975.86
    Params size (MB): 188.32
    Estimated Total Size (MB): 25169045248636.18  # the total is obviously wrong as well
    ----------------------------------------------------------------
    

    The two items Input size (MB) and Estimated Total Size (MB) above are clearly wrong.

    The fix is as follows. First, find where torchsummary is installed:

    import torchsummary
    print(torchsummary.__file__)
    

    The code above prints the install path of torchsummary. On my machine it is:

    /home/guangkun/anaconda3/envs/jet/lib/python3.7/site-packages/torchsummary/__init__.py
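
If you would rather not import the package just to find it, the standard library can locate it. This is a minimal sketch: `package_dir` is a hypothetical helper name, and `json` is used only as a stand-in so that the snippet runs even where torchsummary is not installed.

```python
import importlib.util
from pathlib import Path


def package_dir(name):
    """Return the directory holding the top-level package `name`, without importing it."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(f"cannot locate {name!r}")
    # For a regular package, spec.origin is the path of its __init__.py
    return Path(spec.origin).parent


# 'json' is a stand-in; pass "torchsummary" in practice
print(package_dir("json"))
```

With torchsummary installed, `package_dir("torchsummary") / "torchsummary.py"` should point at the file that is edited in the following steps.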
    

    With the path known, go to that directory. Its contents are:

    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   └── torchsummary.cpython-37.pyc
    └── torchsummary.py
    

    Edit torchsummary.py (around lines 100-103), where the sizes are computed:

    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size
    

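    np.prod(input_size) takes a single product over the whole list, so the dimensions of different inputs get multiplied together. Change the first line so that the element count of each input is computed separately and then summed: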

    total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size
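
To see the difference between the two formulas without touching the installed package, they can be compared side by side. This is only a sketch: the function names and the small (3, 32, 32) shapes are made up for illustration, and `batch_size=-1` is assumed as the default, which is why the expressions are wrapped in `abs`.

```python
import numpy as np


def input_mb_original(input_size, batch_size=-1):
    # Original line: one product over every dimension of every input
    return abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))


def input_mb_patched(input_size, batch_size=-1):
    # Patched line: element count per input, then summed
    return abs(np.sum([np.prod(in_tuple) for in_tuple in input_size])
               * batch_size * 4. / (1024 ** 2.))


one_input = [(3, 32, 32)]
two_inputs = [(3, 32, 32), (3, 32, 32)]

print(input_mb_original(one_input), input_mb_patched(one_input))
# 0.01171875 0.01171875
print(input_mb_original(two_inputs), input_mb_patched(two_inputs))
# 36.0 0.0234375
```

For a single input both formulas agree. With two inputs the original reports 36 MB where the data actually occupies about 0.02 MB, and the gap grows multiplicatively with larger or additional inputs.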
    

    Save the file and run the script again. The summary is now correct:

    ================================================================
    Total params: 49,365,761
    Trainable params: 49,365,761
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 1.64
    Forward/backward pass size (MB): 179.50
    Params size (MB): 188.32
    Estimated Total Size (MB): 369.45
    ----------------------------------------------------------------
    


        Link to this post: https://www.haomeiwen.com/subject/yqfquktx.html