美文网首页
torch 中多GPU训练保存的模型加载方法

torch 中多GPU训练保存的模型加载方法

作者: 小姐姐催我改备注 | 来源:发表于2019-04-10 15:22 被阅读0次

多GPU保存的模型会有前缀modules.

模型加载的方式

state_dict = torch,load("../weight.pth")
这里返回的是一个{"keys":values}
net.load_state_dicts({"key":values})
Missing key(s) in state_dict: Unexpected key(s) in state_dict:
这里会报这样的错误,因为gpu会在网络参数前面加上一个modules.
这里因为state_dict 是一个字典
net.load_state_dict({k.replace("modules","."):v for k, v in state_dict.items()})
state_dict.items()返回字典的健值。

相关文章

网友评论

      本文标题:torch 中多GPU训练保存的模型加载方法

      本文链接:https://www.haomeiwen.com/subject/yboniqtx.html