假设我们在深度学习模型中有一个这样的需求:主要模型A中包含子模块B,而模型B可以通过一定的方式得到一个预训练的权重,模型A需要利用B模型的权重,在此基础上继续训练。
首先我们到官网上去寻找,PyTorch官网上给出了一些保存和加载模型的示例,可以说非常全面总结了模型保存和加载的方法和主义事项,https://pytorch.org/tutorials/beginner/saving_loading_models.html。但是这里的方案都是针对一个完整模型的保存和加载的,不能满足我们这个需求。
因此需要基于此做一些改进,具体如代码所示:
import torch
import torch.nn as nn
class ModelA(nn.Module):
def __init__(self):
super(ModelA, self).__init__()
self.A = nn.Linear(2, 3)
def forward(self, A):
pass
class ModelB(nn.Module):
def __init__(self):
super(ModelB, self).__init__()
self.model_a = ModelA()
self.A = nn.Linear(2, 3)
def forward(self, x):
pass
print("Model")
modelA = ModelA()
modelA_dict = modelA.state_dict()
print('-' * 80)
for key in sorted(modelA_dict.keys()):
parameter = modelA_dict[key]
print(key)
print(parameter.size())
print(parameter)
modelB = ModelB()
modelB_dict = modelB.state_dict()
print('-'*80)
for key in sorted(modelB_dict.keys()):
print('-'*20)
parameter = modelB_dict[key]
print(type(key), key)
print(parameter.size())
print(parameter)
print('-'*20)
print('-'*80)
pretrained_dict = modelA_dict
model_dict = modelB_dict
pretrained_dict = {'model_a.' + k: v for k, v in pretrained_dict.items() if 'model_a.' + k in model_dict}
model_dict.update(pretrained_dict)
modelB.load_state_dict(model_dict)
modelB_dict = modelB.state_dict()
for key in sorted(modelB_dict.keys()):
parameter = modelB_dict[key]
print(key)
print(parameter.size())
print(parameter)
网友评论