美文网首页
[Pytorch] 参数保存剖析

[Pytorch] 参数保存剖析

作者: 全意君 | 来源:发表于2018-08-26 16:16 被阅读0次

    一般Pytorch会将训练好的模型保存至 xxx.pth 文件中
    常用命令:
    torch.load()
    torch.save()

    详细解剖其内部:
    .pth 文件实质上是一个简单的字典文件

    module.features.0.weight 
    module.features.0.bias
    module.features.1.weight 
    module.features.1.bias 
    module.features.1.running_mean
    module.features.1.running_var
    module.features.3.weight
    module.features.3.bias
    module.features.4.weight 
    module.features.4.bias 
    module.features.4.running_mean 
    module.features.4.running_var
    ....
    module.classifier.weight
    module.classifier.bias
    
    DataParallel(
      (module): VGG(
        # 这里的 'features' 其实是自定的名称,下面的'classifier' 同理
        (features): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))  # 对应上面的 features.0.weight, bias 
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) # 对应上面的 *.1.*
          (2): ReLU(inplace)  # 激活曾没有参数所以直接跳过
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
          (5): ReLU(inplace)
          (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)  # 池化曾也没有参数
    ...
          (43): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
          (44): AvgPool2d(kernel_size=1, stride=1, padding=0, ceil_mode=False, count_include_pad=True)
        )
        (classifier): Linear(in_features=512, out_features=10, bias=True)
      )
    )
    
    

    上面的DataParallel 是因为添加了
    net = torch.nn.DataParallel(net)
    这个操作使得网络可以在多GPU 上训练

    VGG(
    ....
    )
    

    pytorch 的参数存储十分简单,如果你想自定义载入的话,直接修改net.state_dict()中的参数就可以了,和python的字典处理一样

    相关文章

      网友评论

          本文标题:[Pytorch] 参数保存剖析

          本文链接:https://www.haomeiwen.com/subject/qzsoiftx.html