美文网首页
pytorch计算FLOPs

pytorch计算FLOPs

作者: 顾北向南 | 来源:发表于2019-07-06 13:18 被阅读0次

    转载于机器之心:https://mp.weixin.qq.com/s/EXuFXbPBIbzTyi0fUjvvPw

    1. 引言

    • 其实模型的参数量好算,但浮点运算数并不好确定,我们一般也就根据参数量直接估计计算量了。但是像卷积之类的运算,它的参数量比较小,但是运算量非常大,它是一种计算密集型的操作。反观全连接结构,它的参数量非常多,但运算量并没有显得那么大。
    • 此外,机器学习还有很多结构没有参数但存在计算,例如最大池化和Dropout等。因此,PyTorch-OpCounter 这种能直接统计 FLOPs 的工具还是非常有吸引力的。
    • PyTorch-OpCounter GitHub 地址:https://github.com/Lyken17/pytorch-OpCounter

    2. OpCouter

    • PyTorch-OpCounter 的安装和使用都非常简单,并且还能定制化统计规则,因此那些特殊的运算也能自定义地统计进去。
    • 我们可以使用 pip 简单地完成安装:pip install thop。不过 GitHub 上的代码总是最新的,因此也可以从 GitHub 上的脚本安装。
    • 对于 torchvision 中自带的模型,Flops 统计通过以下几行代码就能完成:
    from torchvision.models import resnet50
    from thop import profile
    
    model = resnet50()
    input = torch.randn(1, 3, 224, 224)
    flops, params = profile(model, inputs=(input, ))
    
    • 我们测试了一下 DenseNet-121,用 OpCouter 统计了参数量与运算量。API 的输出如下所示,它会告诉我们具体统计了哪些结构,它们的配置又是什么样的。
    • 最后输出的浮点运算数和参数量分别为如下所示,换算一下就能知道 DenseNet-121 的参数量约有 798 万,计算量约有 2.91 GFLOPs。
    flops: 2914598912.0
    parameters: 7978856.0
    

    3. OpCouter 是怎么算的

    • 我们可能会疑惑,OpCouter 到底是怎么统计的浮点运算数。其实它的统计代码在项目中也非常可读,从代码上看,目前该工具主要统计了视觉方面的运算,包括各种卷积、激活函数、池化、批归一化等。例如最常见的二维卷积运算,它的统计代码如下所示:
    
    def count_conv2d(m, x, y):
        x = x[0]
    
        cin = m.in_channels
        cout = m.out_channels
        kh, kw = m.kernel_size
        batch_size = x.size()[0]
    
        out_h = y.size(2)
        out_w = y.size(3)
    
        # ops per output element
        # kernel_mul = kh * kw * cin
        # kernel_add = kh * kw * cin - 1
        kernel_ops = multiply_adds * kh * kw
        bias_ops = 1 if m.bias is not None else 0
        ops_per_element = kernel_ops + bias_ops
    
        # total ops
        # num_out_elements = y.numel()
        output_elements = batch_size * out_w * out_h * cout
        total_ops = output_elements * ops_per_element * cin // m.groups
    
        m.total_ops = torch.Tensor([int(total_ops)])
    
    • 总体而言,模型会计算每一个卷积核发生的乘加运算数,再推广到整个卷积层级的总乘加运算数。

    4. 定制你的运算统计

    • 有一些运算统计还没有加进去,如果我们知道该怎样算,那么就可以写个自定义函数。
    
    class YourModule(nn.Module):
        # your definition
    def count_your_model(model, x, y):
        # your rule here
    
    input = torch.randn(1, 3, 224, 224)
    flops, params = profile(model, inputs=(input, ),
                            custom_ops={YourModule: count_your_model})
    
    • 最后,作者利用这个工具统计了各种流行视觉模型的参数量与 FLOPs 量:

    相关文章

      网友评论

          本文标题:pytorch计算FLOPs

          本文链接:https://www.haomeiwen.com/subject/idjphctx.html