美文网首页
pytorch模型保存和加载

pytorch模型保存和加载

作者: sheng_pan_ai | 来源:发表于2019-02-24 15:15 被阅读0次

    模型保存

    torch.save()实现对网络结构和模型参数的保存.有两种保存方式:一是保存整个神经网络的结构信息和模型参数信息.save的对象是网络net.二是只保留神经网络的训练模型参数,save的对象是net.state_dict()

    torch.save('net1','model.pkl') #保留整个神经网络的结构和模型参数
    torch.save(net1.state_dict(),'model.pkl') # 只保留神经网络的模型参数
    

    模型加载

    对于两种保存方式,重载也有两种方式.
    对应第一种完整网络结构信息,重载的时候通过

    torch.load('model.pkl')
    

    直接初始化新的神经网络对象即可.
    对应第二种只保存模型参数信息,需要首先导入对应的网络,通过

    net.load_state_dict(torch.load('model.pkl'))
    

    完成模型参数的重载
    在网络比较大时,第一种方法会花费较多的时间.

    相关文章

      网友评论

          本文标题:pytorch模型保存和加载

          本文链接:https://www.haomeiwen.com/subject/duxbyqtx.html