模型保存
torch.save()实现对网络结构和模型参数的保存.有两种保存方式:一是保存整个神经网络的结构信息和模型参数信息.save的对象是网络net.二是只保留神经网络的训练模型参数,save的对象是net.state_dict()
torch.save('net1','model.pkl') #保留整个神经网络的结构和模型参数
torch.save(net1.state_dict(),'model.pkl') # 只保留神经网络的模型参数
模型加载
对于两种保存方式,重载也有两种方式.
对应第一种完整网络结构信息,重载的时候通过
torch.load('model.pkl')
直接初始化新的神经网络对象即可.
对应第二种只保存模型参数信息,需要首先导入对应的网络,通过
net.load_state_dict(torch.load('model.pkl'))
完成模型参数的重载
在网络比较大时,第一种方法会花费较多的时间.
网友评论