- 保存网络结构及参数
torch.save(model,'model.pth') # 保存
model = torch.load("model.pth") # 加载
- 只保存模型参数
使用这种方法时需要创建一个和原来一模一样的模型
torch.save(model.state_dict(),"model.pth") # 保存参数
model = model() # 代码中创建网络结构
params = torch.load("model.pth") # 加载参数
model.load_state_dict(params) # 应用到网络结构中
pytorch加载官方提供预训练模型的方法请参考博客
网友评论