美文网首页Pytorch
PyTorch模型保存深入理解

PyTorch模型保存深入理解

作者: 西北小生_ | 来源:发表于2020-02-29 00:28 被阅读0次

前面写过一篇PyTorch保存模型的文章:Pytorch模型保存与加载,并在加载的模型基础上继续训练 ,简单介绍了PyTorch模型保存和加载函数的用法,足以快速上手,但对相关函数和参数的具体用法和代表的含义没有进行展开介绍,这篇文章用于记录之。

PyTorch保存模型的语句是这样的:
torch.save(model.state_dict(), path)
加载是这样的:
model.load_state_dict(torch.load(path))

下面我们将其拆开逐句介绍,深入理解。

1.torch.save()和torch.load()

顾名思义,save函数是PyTorch的存储函数,load函数则是读取函数。save函数可以将各种对象保存至磁盘,包括张量,列表,ndarray,字典,模型等;而相应地,load函数将保存在磁盘上的对象读取出来。

用法:

torch.save(保存对象, 保存路径)
torch.load(文件路径)

应用举例:

保存张量

In [3]: a = torch.ones(3)                                                       

In [4]: a                                                                       
Out[4]: tensor([1., 1., 1.])

In [5]: torch.save(a, './a.pth')          # 保存Tensor               

In [6]: a_load = torch.load('./a.pth')    # 读取Tensor

In [7]: a_load                                                                  
Out[7]: tensor([1., 1., 1.])

保存字典

In [11]: b = {k:v for v,k in enumerate('abc',1)}                                

In [12]: b                                                                      
Out[12]: {'a': 1, 'b': 2, 'c': 3}

In [13]: torch.save(b, './b.rar')                        

In [14]: torch.load('./b.rar')                           
Out[14]: {'a': 1, 'b': 2, 'c': 3}

可以看出,保存和读取非常方便。这里需要注意的是文件的命名,命名必须要有扩展名,扩展名可以为‘xxx.pt’,‘xxx.pth’,‘xxx.pkl’,‘xxx.rar’等形式。

2.model.state_dict()

在PyTorch中,state_dict是一个从参数名称隐射到参数Tesnor的字典对象

In [15]: class MLP(nn.Module): 
    ...:     def __init__(self): 
    ...:         super(MLP, self).__init__() 
    ...:         self.hidden = nn.Linear(3, 2) 
    ...:         self.act = nn.ReLU() 
    ...:         self.output = nn.Linear(2, 1) 
    ...:  
    ...:     def forward(self, x): 
    ...:         a = self.act(self.hidden(x)) 
    ...:         return self.output(a) 
    ...:                                                                        

In [16]: net = MLP()                                                            

In [17]: net.state_dict()                                                       
Out[17]: 
OrderedDict([('hidden.weight', tensor([[ 0.4839,  0.0254,  0.5642],
                      [-0.5596,  0.2602, -0.5235]])),
             ('hidden.bias', tensor([-0.4986, -0.5426])),
             ('output.weight', tensor([[0.0967, 0.4980]])),
             ('output.bias', tensor([-0.4520]))])

可以看出,state_dict()返回的是一个有序字典,该字典的键即为模型定义中有参数的层的名称+weight或+bias,值则对应相应的权重或偏差,无参数的层则不在其中。

除了模型中有参数的层(卷积层、线性层等)有state_dict,优化器对象:

optimizer = torch.optim.xxxx(...)    # 如SGD,Adam等

也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。以及,学习率调整器对象:

scheduler = torch.optim.lr_scheduler.xxxx(...)    # 如LambdaLR,CosineAnnealingLR等

也有一个state_dict,其中包含当前学习率的值以及迭代次数记录。

如果有程序中断后继续接着训练的需求,最好将这些状态字典都以字典形式保存下来:

check_point = {'lr': scheduler.state_dict(), 'optimizer': optimizer.state_dict(), 'model': model.state_dict()}
torch.save(check_point, path)

恢复时只需要在相应对象实例化之后进行加载即可:

check_point = torch.load(path) 
... ...
model = xxxNet(...)
model.load_state_dict(check_point['model'])
... ... 
optimizer = torch.optim.xxxx(...)
optimizer.load_state_dict(check_point['optimizer'])
... ... 
scheduler = torch.optim.lr_scheduler.xxxxLR(...) 
scheduler.load_state_dict(check_point['lr']) 

3.model.load_state_dict()

这是模型加载state_dict的语句,也就是说,它的输入是一个state_dict,也就是一个字典。模型定义好并且实例化后会自动进行初始化,上面的例子中我们定义的模型MLP在实例化以后显示的模型参数都是自动初始化后的随机数。

在训练模型或者迁移学习中我们会使用已经训练好的参数来加速训练过程,这时候就用load_state_dict()语句加载训练好的参数并将其覆盖在初始化参数上,也就是说执行过此语句后,加载的参数将代替原有的模型参数。

既然加载的是一个字典,那么需要注意的就是字典的键一定要相同才能进行覆盖,比如加载的字典中的'hidden.weight'只能覆盖当前模型的'hidden.weight',如果键不同,则不能实现有效覆盖操作。键相同而值的shape不同,则会将新的键值对覆盖原来的键值对,这样在训练时会报错。所以我们在加载前一般会进行数据筛选,筛选是对字典的键进行对比来操作的:

pretrained_dict = torch.load(log_dir)  # 加载参数字典
model_state_dict = model.state_dict()  # 加载模型当前状态字典
pretrained_dict_1 = {k:v for k,v in pretrained_dict.items() if k in model_state_dict}  # 过滤出模型当前状态字典中没有的键值对
model_state_dict.update(pretrained_dict_1)  # 用筛选出的参数键值对更新model_state_dict变量
model.load_state_dict(model_state_dict)  # 将筛选出的参数键值对加载到模型当前状态字典中

以上代码简单的对预训练参数进行了过滤和筛选,主要是通过第3条语句粗略的过滤了键值对信息,进行筛选后要用Python更新字典的方法update()来对模型当前字典进行更新,update()方法将pretrained_dict_1中的键值对添加到model_state_dict中,若pretrained_dict_1中的键名和model_state_dict中的键名相同,则覆盖之;若不同,则作为新增键值对添加到model_state_dict中。显然,这里需要的是将pretrained_dict_1中的键值对覆盖model_state_dict的相应键值对,所以对应的键的名称必须相同,所以第3条语句中按键名称进行筛选,过滤出当前模型字典中没有的键值对。否则会报错。

如果想要细粒度过滤或更改某些参数的维度,如进行卷积核参数维度的调整,假如预训练参数里conv1有256个卷积核,而当前模型只需要200个卷积核,那么可以采用类似以下语句直接对字典进行更改:

pretrained_dict['conv1.weight'] = pretrained_dict['conv1.weight'][:200,:,:,:]   # 假设保留前200个卷积核

以上。

相关文章

  • PyTorch模型保存深入理解

    前面写过一篇PyTorch保存模型的文章:Pytorch模型保存与加载,并在加载的模型基础上继续训练 ,简单介绍了...

  • pytorch之保存与加载模型

    pytorch之保存与加载模型 本篇笔记译自pytorch官网tutorial,用于方便查看。pytorch与保存...

  • Pytorch模型保存与加载,并在加载的模型基础上继续训练

    pytorch保存模型非常简单,主要有两种方法: 只保存参数;(官方推荐) 保存整个模型 (结构+参数)。由于保存...

  • Pytorch Tips

    保存、恢复模型参数参考:pytorch学习笔记(五):保存和加载模型 中断时保存参数 将该代码添加至save_mo...

  • Pytorch深度模型保存和加载

    Pytorch保存模型的两种方式: 1 模型结构和模型参数都保存下来 优点:不需要预初始化模型,直接加载,就可以获...

  • Pytorch: 模型保存和加载

    Pytorch 框架下,模型的保存和加载有两种方式,一种是只保存模型参数,一种是保存模型网络及参数,两种加载的方式...

  • 模型保存

    关于pytorch模型保存,在训练过程中常用,记录总结一下 如何保存和重新加载微调模型,通常需要保存三种文件类型才...

  • pytorch--1数据加载

    构建数据Dataset和DataLoader 构建网络 参考: PyTorch之保存加载模型pytorchyolo...

  • Pytorch 之 模型的保存与调用

    介绍关于用pytorch搭建模型时,对模型进行保存以及再次调用模型参数的相关函数命令。 使用torch.save(...

  • pytorch:model save & model load

    pytorch的模型保存与恢复~ 首先pytorch官网doc中推荐两种方法。link 然而在需要注意的是: 方法...

网友评论

    本文标题:PyTorch模型保存深入理解

    本文链接:https://www.haomeiwen.com/subject/pwjdhhtx.html