直接加载所有权重
适用于直接使用别人的模型和权重
model=CNN()
model.load_state_dict(torch.load('cnn.pth'))
加载部分权重
适用于对别人的模型做了适当的修改
model=CNN()
pretrained_dict = torch.load('cnn.pth'))
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #用于过滤掉修改结构处的权重
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
网友评论