美文网首页机器学习神经网络pytorch
pytorch之保存与加载模型

pytorch之保存与加载模型

作者: zhaoQiang012 | 来源:发表于2019-01-14 17:31 被阅读0次

pytorch之保存与加载模型

本篇笔记译自pytorch官网tutorial,用于方便查看。
pytorch与保存、加载模型有关的常用函数3个:

  • torch.save(): 保存一个序列化的对象到磁盘,使用的是Pythonpickle库来实现的。
  • torch.load(): 解序列化一个pickled对象并加载到内存当中。
  • torch.nn.Module.load_state_dict(): 加载一个解序列化的state_dict对象

1. state_dict

PyTorch中所有可学习的参数保存在model.parameters()中。state_dict是一个Python字典。保存了各层与其参数张量之间的映射。torch.optim对象也有一个state_dict,它包含了optimizerstate,以及一些超参数。

2. 保存&加载模型来inference(recommended)

save

torch.save(model.state_dict(), PATH)

load

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # 当用于inference时不要忘记添加
  • 保存的文件名后缀可以是.pt.pth
  • 当用于inference时不要忘记添加model.eval()

3. 保存&加载整个模型(not recommended)

save

torch.save(model, PATH)

load

# Model class must be defined somewhere
model = torch.load()
model.eval()

4. 保存&加载带checkpoint的模型用于inferenceresuming training

save

torch.save({
  'epoch': epoch,
  'model_state_dict': model.state_dict(),
  'optimizer_state_dict': optimizer.state_dict(),
  'loss': loss,
  ...
  }, PATH)

load

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# or
model.train()

5. 保存多个模型到一个文件中

save

torch.save({
  'modelA_state_dict': modelA.state_dict(),
  'modelB_state_dict': modelB.state_dict(),
  'optimizerA_state_dict': optimizerA.state_dict(),
  'optimizerB_state_dict': optimizerB.state_dict(),
  ...
  }, PATH)

load

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelAClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict']
modelB.load_state_dict(checkpoint['modelB_state_dict']
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']

modelA.eval()
modelB.eval()
# or
modelA.train()
modelB.train()
  • 此情况可能在GANSequence-to-sequence,或ensemble models中使用
  • 保存checkpoint常用.tar文件扩展名

6. Warmstarting Model Using Parameters From A Different Model

save

torch.save(modelA.state_dict(), PATH)

load

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
  • 在迁移训练时,可能希望只加载部分模型参数,此时可置strict参数为False来忽略那些没有匹配到的keys

7. 保存&加载模型跨设备

(1) Save on GPU, Load on CPU
save

torch.save(model.state_dict(), PATH)

load

device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

(2) Save on GPU, Load on GPU
save

torch.save(model.state_dict(), PATH)

load

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

(3) Save on CPU, Load on GPU
save

torch.save(model.state_dict(), PATH)

load

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)

8. 保存torch.nn.DataParallel模型

save

torch.save(model.module.state_dict(), PATH)

load

# Load to whatever device you want

相关文章

  • pytorch之保存与加载模型

    pytorch之保存与加载模型 本篇笔记译自pytorch官网tutorial,用于方便查看。pytorch与保存...

  • PyTorch模型保存深入理解

    前面写过一篇PyTorch保存模型的文章:Pytorch模型保存与加载,并在加载的模型基础上继续训练 ,简单介绍了...

  • pytorch--1数据加载

    构建数据Dataset和DataLoader 构建网络 参考: PyTorch之保存加载模型pytorchyolo...

  • PyTorch之保存加载模型

    前提 本文来源于https://pytorch.org/tutorials/beginner/saving_loa...

  • pytorch模型加载与保存

    pytorch模型训练流程 配置超参数: epoch,learning_rate 构建数据集:训练集,验证集,测试...

  • PyTorch 之模型的保存与加载

    1. torch.save 主要参数 obj: 对象 f:输出路径 2. torch.load 主要参数 f: 文...

  • Pytorch Tips

    保存、恢复模型参数参考:pytorch学习笔记(五):保存和加载模型 中断时保存参数 将该代码添加至save_mo...

  • Pytorch: 模型保存和加载

    Pytorch 框架下,模型的保存和加载有两种方式,一种是只保存模型参数,一种是保存模型网络及参数,两种加载的方式...

  • 【深度学习DL-PyTorch】六、保存和加载模型

    使用 PyTorch 保存和加载模型。 一、 训练网络 二、 保存和加载网络 每次需要使用网络时都去训练它不太现实...

  • Pytorch深度模型保存和加载

    Pytorch保存模型的两种方式: 1 模型结构和模型参数都保存下来 优点:不需要预初始化模型,直接加载,就可以获...

网友评论

    本文标题:pytorch之保存与加载模型

    本文链接:https://www.haomeiwen.com/subject/cifjdqtx.html