美文网首页
Pytorch Tips

Pytorch Tips

作者: SnorlaxSE | 来源:发表于2019-10-18 16:11 被阅读0次
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
  • 中断时保存参数
try:
    train_net(net=net, epochs=args.epochs, batch_size=args.batchsize,
              lr=args.lr, gpu=args.gpu, img_scale=args.scale)
except KeyboardInterrupt:  # 用户中断执行(通常是输入^C)
    import time
    save_time = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
    torch.save(net.state_dict(), '{}_INTERRUPTED.pth'.format(save_time))
    print('Saved interrupt')
    try:
        sys.exit(0)
    except SystemExit:
        os._exit(0)

将该代码添加至save_model合适的位置,可实现“Early Stopping”

相关文章

网友评论

      本文标题:Pytorch Tips

      本文链接:https://www.haomeiwen.com/subject/yfzimctx.html