老娘真的要吐血啦, pytorch版本load pretrained model不兼容和matlab is unable to call pytorch0.4。
至于为什么matlab不能call pytorch0.4, 网上有很多猜测,比如matlablibc++版本太老,线程的原因,反正都不靠谱,只能回退。
测试别人的网络,要求pytorch0.4以上,然后我就更新啦, 并且在0.4的版本下训练了自己的网络,然后我发现matlab无法call pytorch0.4. 所以我就回退了版本
conda install pytorch=0.3.1 torchvision cuda80 -c pytorch
回退以后出现pytorch无法load_state_dict,会出现各种各样的unexpected错误。
error 1
error: ‘module’ object has no attribute ‘_rebuild_tensor_v2’
solution
在import torch 之后,加上
import torch._utils
try:
torch._utils._rebuild_tensor_v2
except AttributeError:
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
tensor= torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
tensor.requires_grad= requires_grad
tensor._backward_hooks= backward_hooks
return tensor
torch._utils._rebuild_tensor_v2= _rebuild_tensor_v2
error 2
Unexpected key(s) in state_dict: batches_tracked”
solution
pretrained_dict= torch.load('path to .pth.tar') %自己的路径
model_dict= model.state_dict()
pretrained_dict= {k: vfor k, vin pretrained_dict.items()if kin model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(pretrained_dict)
error 3
nn.InstanceNorm3d 无法存储,目前还没有找到答案
网友评论