美文网首页
PyTorch如何恢复指定权重

PyTorch如何恢复指定权重

作者: yalesaleng | 来源:发表于2018-08-28 17:20 被阅读123次

    1. 如何从已训练好的网络模型中提取指定层权重

    import torch 
    # vgg为官方提供的model
    # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
    import vgg
    
    model = torch.load('logs/vgg16.pkl') 
    
    restore_param = ['classifier.2.bias']
    # 当然 如果你的目的是不想导入某些层的权重,将下述代码改为`if not k in restore_param`
    restore_param = {v for k, v in model.state_dict().items() if k in restore_param}
    print(restore_param)
    
    
    ------>:
    {tensor([-0.0048,  0.0048], device='cuda:0')}
    

    2. 如何加载模型部分参数并更新

    import torch
    import vgg
    
    model = torch.load('logs/vgg16.pkl')
    vgg16 = vgg.vgg16().cuda()
    vgg16_dict = vgg16.state_dict()
    for k, v in vgg16_dict.items():
        print(v)
    
    print()
    print('##################################################################################')
    print()
    
    restore = ['classifier.2.bias']
    restore_param = {k: v for k, v in model.state_dict().items() if k in restore}
    vgg16_dict.update(restore_param)
    for k, v in vgg16_dict.items():
        print(v)
    
    
    ------>:
    tensor([[[[-0.0198,  0.0425, -0.0221],
              [ 0.0636,  0.0193, -0.0661],
              [-0.0035,  0.0031, -0.0395]],
    
             [[-0.0525,  0.0796,  0.0263],
              [-0.0669,  0.1537,  0.1025],
              [ 0.0002, -0.0456, -0.0086]],
    
             [[-0.0344,  0.0566, -0.0090],
              [ 0.0915,  0.0133, -0.0007],
              [-0.0228, -0.0143,  0.0841]]],
    ...
    tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
    tensor([[ 2.7670e-03, -1.6860e-02, -6.6972e-03,  ...,  6.7144e-03,
             -7.2912e-03,  2.0684e-03],
            [ 4.2978e-03, -9.8524e-03,  1.2163e-02,  ...,  6.3420e-03,
             -5.1077e-03,  6.4550e-03]], device='cuda:0')
    tensor([0., 0.], device='cuda:0')
    
    ##################################################################################
    
    tensor([[[[-0.0198,  0.0425, -0.0221],
              [ 0.0636,  0.0193, -0.0661],
              [-0.0035,  0.0031, -0.0395]],
    
             [[-0.0525,  0.0796,  0.0263],
              [-0.0669,  0.1537,  0.1025],
              [ 0.0002, -0.0456, -0.0086]],
    
             [[-0.0344,  0.0566, -0.0090],
              [ 0.0915,  0.0133, -0.0007],
              [-0.0228, -0.0143,  0.0841]]],
    ...
    tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
    tensor([[ 2.7670e-03, -1.6860e-02, -6.6972e-03,  ...,  6.7144e-03,
             -7.2912e-03,  2.0684e-03],
            [ 4.2978e-03, -9.8524e-03,  1.2163e-02,  ...,  6.3420e-03,
             -5.1077e-03,  6.4550e-03]], device='cuda:0')
    tensor([-0.0048,  0.0048], device='cuda:0')
    

    可以发现classifier.2.bias的值由[0., 0.]变为了[-0.0048, 0.0048]

    参考文章

    相关文章

      网友评论

          本文标题:PyTorch如何恢复指定权重

          本文链接:https://www.haomeiwen.com/subject/poeqwftx.html