美文网首页
PyTorch如何打印模型详细信息

PyTorch如何打印模型详细信息

作者: 雪糕遇上夏天 | 来源:发表于2022-08-30 16:19 被阅读0次

    我们以resnet18为例,介绍几种获取模型摘要的方法。

    import torchvistion
    model = torchvision.models.resnet18()
    

    1.直接使用PrettyTable

    from prettytable import PrettyTable
    
    table = PrettyTable(['Modules', 'Parameters']) 
    total_params = 0 
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table) 
    print(f'Total Trainable Params: {total_params}') 
    

    效果如下:


    PrettyTable

    比较简单,也没有模型的输入输出情况。

    2. TorchSummary

    from torchsummary import summary
    summary(model, input_size = (3, 64, 64), batch_size = -1)
    
    TorchSummary

    整体看美观了很多,也有了输出的维度。但是如果能打印出模型的层次结构就更好了。

    3. torchinfo

    import torchinfo 
    torchinfo.summary(model, (3, 224, 224), batch_dim = 0, col_names = ('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 0)
    
    torchinfo

    这种方式更加美观,且内容详细,灰常棒。

    相关文章

      网友评论

          本文标题:PyTorch如何打印模型详细信息

          本文链接:https://www.haomeiwen.com/subject/aoglnrtx.html