Pytorch 框架下,模型的保存和加载有两种方式,一种是只保存模型参数,一种是保存模型网络及参数,两种加载的方式也不一样。
保存模型参数的方式
torch.save(model.state_dict(), 'model_path')
model.load_state_dict(torch.load('model_path'))
保存模型网络和参数的方式
torch.save(model, 'model_path')
model = torch.load('model_path')
Pytorch 框架下,模型的保存和加载有两种方式,一种是只保存模型参数,一种是保存模型网络及参数,两种加载的方式也不一样。
保存模型参数的方式
torch.save(model.state_dict(), 'model_path')
model.load_state_dict(torch.load('model_path'))
保存模型网络和参数的方式
torch.save(model, 'model_path')
model = torch.load('model_path')
本文标题:Pytorch: 模型保存和加载
本文链接:https://www.haomeiwen.com/subject/zqptsctx.html
网友评论