美文网首页
Pytorch Tips

Pytorch Tips

作者: SnorlaxSE | 来源:发表于2019-10-18 16:11 被阅读0次
    # 保存和加载整个模型
    torch.save(model_object, 'model.pkl')
    model = torch.load('model.pkl')
    
    # 仅保存和加载模型参数(推荐使用)
    torch.save(model_object.state_dict(), 'params.pkl')
    model_object.load_state_dict(torch.load('params.pkl'))
    
    • 中断时保存参数
    try:
        train_net(net=net, epochs=args.epochs, batch_size=args.batchsize,
                  lr=args.lr, gpu=args.gpu, img_scale=args.scale)
    except KeyboardInterrupt:  # 用户中断执行(通常是输入^C)
        import time
        save_time = time.strftime("%Y-%m-%d-%H-%M", time.localtime())
        torch.save(net.state_dict(), '{}_INTERRUPTED.pth'.format(save_time))
        print('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
    

    将该代码添加至save_model合适的位置,可实现“Early Stopping”

    相关文章

      网友评论

          本文标题:Pytorch Tips

          本文链接:https://www.haomeiwen.com/subject/yfzimctx.html