美文网首页
Pytorch: 用thop计算pytorch模型的FLOPs

Pytorch: 用thop计算pytorch模型的FLOPs

作者: wzNote | 来源:发表于2019-09-29 18:46 被阅读0次

    安装thop

    pip install thop
    

    基础用法

    • 以查看resnet50的FLOPs为例
    from torchvision.models import resnet50
    from thop import profile
    model = resnet50()
    input = torch.randn(1, 3, 224, 224)
    flops, params = profile(model, inputs=(input, ))
    
    • 查看自己模型的FLOPs
    class YourModule(nn.Module):
        # your definition
    def count_your_model(model, x, y):
        # your rule here
    
    input = torch.randn(1, 3, 224, 224)
    flops, params = profile(model, inputs=(input, ), 
                            custom_ops={YourModule: count_your_model})
    
    • 提升输出结果的可读性
      调用thop.clever_format
    from thop import clever_format
    flops, params = clever_format([flops, params], "%.3f")
    

    参考:https://github.com/Lyken17/pytorch-OpCounter

    相关文章

      网友评论

          本文标题:Pytorch: 用thop计算pytorch模型的FLOPs

          本文链接:https://www.haomeiwen.com/subject/fgkqpctx.html