美文网首页
Pytorch: 加载部分权重

Pytorch: 加载部分权重

作者: wzNote | 来源:发表于2019-11-21 23:54 被阅读0次

直接加载所有权重

适用于直接使用别人的模型和权重

model=CNN()
model.load_state_dict(torch.load('cnn.pth'))

加载部分权重

适用于对别人的模型做了适当的修改

model=CNN()
pretrained_dict = torch.load('cnn.pth'))
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #用于过滤掉修改结构处的权重
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

相关文章

网友评论

      本文标题:Pytorch: 加载部分权重

      本文链接:https://www.haomeiwen.com/subject/pdgxwctx.html