1. torch.save
Main parameters:
- obj: the object to save
- f: the output path (or a file-like object)
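Here is a minimal sketch of the two parameters. It assumes a small hypothetical dict of tensors. `obj` can be any picklable object, and `f` can be either a file path or a file-like object such as `io.BytesIO`:

```python
import io
import os
import tempfile

import torch

# obj: any picklable object; here a hypothetical dict holding a tensor and an int
obj = {"weights": torch.ones(2, 3), "step": 10}

# f as a file path (a temporary directory keeps the example self-contained)
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "obj.pkl")
torch.save(obj, path)

# f as an in-memory file-like object
buffer = io.BytesIO()
torch.save(obj, buffer)
buffer.seek(0)  # rewind the buffer before reading it back
restored = torch.load(buffer)
print(restored["step"])  # 10
```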
2. torch.load
Main parameters:
- f: path of the file to load
- map_location: where to place the loaded tensors, e.g. CPU or GPU
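This sketch shows `map_location` in use. The tensor is saved from the GPU if one is available and from the CPU otherwise. `map_location="cpu"` then forces everything back onto the CPU, which is the usual way to open a GPU-trained file on a CPU-only machine. `map_location` also accepts a `torch.device` or a device string such as `"cuda:0"`.

```python
import io

import torch

# Put the tensor on a GPU if there is one, to show the remapping
device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.arange(4.0, device=device)

buffer = io.BytesIO()
torch.save(t, buffer)
buffer.seek(0)

# map_location="cpu" loads every tensor onto the CPU, even ones saved from a GPU
loaded = torch.load(buffer, map_location="cpu")
print(loaded.device)  # cpu
```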
Method 1: save the entire module (slower, uses more memory)
Save:
torch.save(net, path)
Load:
path_model = './model.pkl'
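net_load = torch.load(path_model, weights_only=False)  # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a whole module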
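Below is a self-contained round trip for Method 1, assuming a hypothetical two-layer-free `Net` with a single `nn.Linear`. Because the whole object is pickled, the `Net` class definition must be importable when the file is loaded. `weights_only=False` is only needed on PyTorch versions where `torch.load` defaults to `weights_only=True`, and it should only be used on trusted files, since unpickling can execute code.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
path_model = os.path.join(tempfile.mkdtemp(), "model.pkl")

# Pickles the whole module object, which refers to the Net class by name
torch.save(net, path_model)

# A full module is not a "weights only" object, so this flag must be False
net_load = torch.load(path_model, weights_only=False)

x = torch.randn(3, 4)
print(torch.allclose(net(x), net_load(x)))  # True
```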
Method 2: save only the model parameters (officially recommended)
Save:
state_dict = net.state_dict()
torch.save(state_dict, path)
Load:
path_state_dict = './model_state_dict.pkl'
state_dict_load = torch.load(path_state_dict)
net.load_state_dict(state_dict_load)  # net must already be an instance of the same model class
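The following sketch is a matching round trip for Method 2, using the same hypothetical `Net`. The file stores only a mapping from parameter names to tensors. Loading therefore has two steps: build the architecture first, then copy the weights in with `load_state_dict`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
path_state_dict = os.path.join(tempfile.mkdtemp(), "model_state_dict.pkl")

# Save only parameters and buffers (an ordered dict of name -> tensor)
torch.save(net.state_dict(), path_state_dict)

# Rebuild the architecture, then fill in the saved weights
net_new = Net()
state_dict_load = torch.load(path_state_dict)
net_new.load_state_dict(state_dict_load)

x = torch.randn(3, 4)
print(torch.allclose(net(x), net_new(x)))  # True
print(list(state_dict_load.keys()))  # ['fc.weight', 'fc.bias']
```

Because the file holds only tensors, it does not depend on pickling the `Net` class itself. This is one reason the approach is generally preferred.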
3. Resuming training from checkpoints
Saving a checkpoint (inside the epoch loop):
if (epoch + 1) % checkpoint_interval == 0:  # save once every checkpoint_interval epochs
    checkpoint = {"model_state_dict": net.state_dict(),              # model parameters
                  "optimizer_state_dict": optimizer.state_dict(),    # optimizer state
                  "epoch": epoch}                                    # index of the epoch just finished
    path_checkpoint = './checkpoint_{}_epoch.pkl'.format(epoch)
    torch.save(checkpoint, path_checkpoint)
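Here the checkpoint step is placed inside a complete, if tiny, training loop. Everything apart from the checkpoint block is a hypothetical setup: a single `nn.Linear`, SGD with momentum, a `StepLR` scheduler, and random data. With `checkpoint_interval = 5` and 10 epochs, checkpoints are written after epoch indices 4 and 9.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical setup: a tiny linear model trained on fixed random data
net = nn.Linear(4, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
inputs, targets = torch.randn(16, 4), torch.randn(16, 1)
loss_fn = nn.MSELoss()

max_epoch = 10
checkpoint_interval = 5
ckpt_dir = tempfile.mkdtemp()

for epoch in range(max_epoch):
    optimizer.zero_grad()
    loss = loss_fn(net(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()

    if (epoch + 1) % checkpoint_interval == 0:  # save every checkpoint_interval epochs
        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = os.path.join(ckpt_dir, "checkpoint_{}_epoch.pkl".format(epoch))
        torch.save(checkpoint, path_checkpoint)

print(sorted(os.listdir(ckpt_dir)))  # ['checkpoint_4_epoch.pkl', 'checkpoint_9_epoch.pkl']
```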
Resuming from a checkpoint:
path_checkpoint = './checkpoint_4_epoch.pkl'
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
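start_epoch = checkpoint['epoch']  # last completed epoch, so continue the loop from start_epoch + 1
scheduler.last_epoch = start_epoch  # alternatively, store scheduler.state_dict() in the checkpoint and restore it with scheduler.load_state_dict()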
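The sketch below checks that resuming really continues training where it left off. It compares an uninterrupted 10-epoch run with a run that stops after epoch index 4, saves a checkpoint, and resumes in freshly built objects. One deliberate change from the recipe above: it also stores `scheduler.state_dict()` and restores it with `scheduler.load_state_dict()` instead of setting `scheduler.last_epoch` by hand, because that captures the whole scheduler state in one step. The model, data, and helper functions are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
inputs, targets = torch.randn(16, 4), torch.randn(16, 1)
# Shared starting weights so both runs begin from the same point (hypothetical model)
init_state = {k: v.clone() for k, v in nn.Linear(4, 1).state_dict().items()}


def make_training_objects():
    """Hypothetical helper: fresh model/optimizer/scheduler starting from init_state."""
    net = nn.Linear(4, 1)
    net.load_state_dict(init_state)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
    return net, optimizer, scheduler


def train_epochs(net, optimizer, scheduler, epochs):
    """Hypothetical helper: one full-batch step per epoch, then one scheduler step."""
    for _ in epochs:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(net(inputs), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()


max_epoch = 10

# Reference: one uninterrupted run
ref_net, ref_opt, ref_sched = make_training_objects()
train_epochs(ref_net, ref_opt, ref_sched, range(max_epoch))

# Interrupted run: train epochs 0..4, then write a checkpoint
net, optimizer, scheduler = make_training_objects()
train_epochs(net, optimizer, scheduler, range(5))
path_checkpoint = os.path.join(tempfile.mkdtemp(), "checkpoint_4_epoch.pkl")
torch.save({"model_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),  # extra key, not in the recipe above
            "epoch": 4}, path_checkpoint)

# Resume in fresh objects, as if the process had restarted
net, optimizer, scheduler = make_training_objects()
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # restores lr and momentum buffers
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
start_epoch = checkpoint["epoch"]
train_epochs(net, optimizer, scheduler, range(start_epoch + 1, max_epoch))

print(torch.allclose(net.weight, ref_net.weight))  # True
```

Restoring the optimizer state is what carries over the momentum buffers and the current learning rate. Restoring only the model weights would resume from the right parameters but with a reset optimizer.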