美文网首页
PyTorch中如何加载子模块的权重

PyTorch中如何加载子模块的权重

作者: WritingHere | 来源:发表于2021-01-02 19:28 被阅读0次

    假设我们在深度学习模型中有一个这样的需求:主要模型A中包含子模块B,而模型B可以通过一定的方式得到一个预训练的权重,模型A需要利用B模型的权重,在此基础上继续训练。
    首先我们到官网上去寻找,PyTorch官网上给出了一些保存和加载模型的示例,可以说非常全面总结了模型保存和加载的方法和主义事项,https://pytorch.org/tutorials/beginner/saving_loading_models.html。但是这里的方案都是针对一个完整模型的保存和加载的,不能满足我们这个需求。
    因此需要基于此做一些改进,具体如代码所示:

    import torch
    import torch.nn as nn
    
    class ModelA(nn.Module):
        def __init__(self):
            super(ModelA, self).__init__()
            self.A = nn.Linear(2, 3)
        
        def forward(self, A):
            pass
    
    
    class ModelB(nn.Module):
        def __init__(self):
            super(ModelB, self).__init__()
            self.model_a = ModelA()
            self.A = nn.Linear(2, 3)
        
        def forward(self, x):
            pass
    
    print("Model")
    modelA = ModelA()
    modelA_dict = modelA.state_dict()
    print('-' * 80)
    for key in sorted(modelA_dict.keys()):
        parameter = modelA_dict[key]
        print(key)
        print(parameter.size())
        print(parameter)
    modelB = ModelB()
    modelB_dict = modelB.state_dict()
    print('-'*80)
    for key in sorted(modelB_dict.keys()):
        print('-'*20)
        parameter = modelB_dict[key]
        print(type(key), key)
        print(parameter.size())
        print(parameter)
        print('-'*20)
    print('-'*80)
    
    pretrained_dict = modelA_dict
    model_dict = modelB_dict
    
    pretrained_dict = {'model_a.' + k: v for k, v in pretrained_dict.items() if 'model_a.' + k in model_dict}
    model_dict.update(pretrained_dict)
    
    modelB.load_state_dict(model_dict)
    modelB_dict = modelB.state_dict()
    for key in sorted(modelB_dict.keys()):
        parameter = modelB_dict[key]
        print(key)
        print(parameter.size())
        print(parameter)
    

    相关文章

      网友评论

          本文标题:PyTorch中如何加载子模块的权重

          本文链接:https://www.haomeiwen.com/subject/yxtyoktx.html