美文网首页
pytorch 加载预训练模型

pytorch 加载预训练模型

作者: i_1312 | 来源:发表于2020-06-28 09:30 被阅读0次

1 加载内置的预训练模型

model_ft = models.resnet18(pretrained=use_pretrained)
#然后可以先预训练的模型加入新的层
self.conv1 = model_ft.conv1
self.bn = model_ft.bn

2 加载自己定义的预训练的模型

##首先保存模型到checkpoint.pth
torch.save(model.module.state_dict(), ‘checkpoint.pth’)

##然后加载预训练的模型
mymodel.load_state_dict(torch.load(‘checkpoint.pth’))

3 加载部分模型参数

# 加载模型
model_pretrained = models.resnet18(pretrained=use_pretrained)

# mymodel's state_dict,
# 如:  conv1.weight 
#     conv1.bias  
mymodelB_dict = mymodelB.state_dict()

# 将model_pretrained的建与自定义模型的建进行比较,剔除不同的
pretrained_dict = {k: v for k, v in model_pretrained.items() if k in mymodelB_dict}
# 更新现有的model_dict
mymodelB_dict.update(pretrained_dict)

# 加载我们真正需要的state_dict
mymodelB.load_state_dict(mymodelB_dict)

参考链接https://www.cnblogs.com/geo-will/p/11311608.html

相关文章

网友评论

      本文标题:pytorch 加载预训练模型

      本文链接:https://www.haomeiwen.com/subject/shrifktx.html