美文网首页
Pytorch 之 模型的保存与调用

Pytorch 之 模型的保存与调用

作者: Allard_c205 | 来源:发表于2021-08-19 15:28 被阅读0次

    介绍关于用pytorch搭建模型时,对模型进行保存以及再次调用模型参数的相关函数命令。

    使用 torch.save(model.state_dict(), PATH)来保存模型学习到的参数,给模型恢复提供最大的灵活性。

    先对模型进行实例化,再用load_state_dict()调用模型,在对模型进行推理之前,调用model.eval():

    model = TheModelClass(*args, **kwargs)

    model.load_state_dict(torch.load(PATH))  #该函数只接收字典对象,而不是保存对象的路径,在这之前要反序列化保存的state_dict。

    model.eval()


    torch.load( f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>,  **pickle_load_args)  从文件加载用torch.save()保存的对象。  目前需要知道该函数前两个参数的正确使用即可

    f: 类似于文件的对象,或包含文件名称的字符串,如:要载入的模型所在的完整路径的字符串

    map_location: 一个函数,torch.device,字符串或字典,明确如何重映射存储空间位置

    pickle_module:用于解开元数据和对象的模块(必须与序列化文件的pickle_module相匹配)

    pickle_load_args:(只有Python3才有)可选择的关键字参数,并传递给pickle_module.load()和pickle_module.Unpickler(),比如,errors=...。

    相关文章

      网友评论

          本文标题:Pytorch 之 模型的保存与调用

          本文链接:https://www.haomeiwen.com/subject/fhzobltx.html