美文网首页
Pytorch预训练模型finetune

Pytorch预训练模型finetune

作者: 黑恶歌王 | 来源:发表于2018-12-16 11:52 被阅读0次

    这一块实在是因为之前没有过pytorch的经验,从0开始一步一步摸滚打爬。而且发现自己手总是处于闲置状态实在不好,每天还是写一些东西防止遗忘的过快吧。毕竟同一件事在不同的时间段看总会有一些不同的体会。
    首先是说这一块的预训练模型,pytorch不愧是最有用户体验的一个工具,官方文档中这里给了一个resnet50的imagenet的预训练模型,具体的引用也给了一个很好的例子。导入torchvision包就可以解决这个问题。
    下面这一个是从csdn blog中找到的一个说是能适配预训练模型的操作。

    pretrained_params = torch.load('Pretrained_Model') 
    model = The_New_Model(xxx) 
    model.load_state_dict(pretrained_params.state_dict(), strict=False)
    

    这里的load是把pretrained的模型导入进来,转化的是一个字典类型,然后如果下方load_state_dict的时候确实是把对应层的参数传到进来,strict=False给的解释是如果没有和其匹配的参数就会摒弃。这样对于现在的工作确实倒是匹配进去了,打算在服务器上试试看结果如何。
    发现上述问题并不是匹配进去了,因为所有层的关键词都有一些命名问题,这样的话首先第一点就是解决每一层命名读不到(因为是个字典类型)。更正参考nonlocal中的convert_model(在做。。)
    经过如下操作学习得到了可以将模型中对应的key进行更改。一开始的读取方式有误,实际上不用enumerate读取就可以直接得到key和value。加上enumerate相当于自己给他定义了key和v导致读取有问题。

    pickle读取问题:

    这里我们用python3实验的话会产生一个无法读取assii码的问题,用下面转换加一个encoding方式解决。这里很感谢网上一位博主的说明,十分的有针对性。

    import torch
    import torchvision
    import collections
    import pickle
    #net=torch.load('resnet50-19c8e357.pth')
    f=open('r50_pretrain_c2_model_iter450450_clean.pkl','rb')
    net=pickle.load(f,encoding='bytes')
    new_dict=collections.OrderedDict()
    print(net)
    

    然而用pickle读取转到pth并不清楚中间有什么处理过程,待研究。最后是强行存储为pth文件后会报错,暂不清楚原因。
    和实验室同一级大家比我确实拉了一大截了,这一点我很清楚,现在还是先把能做的做好,其他的在实践中慢慢补充基础知识吧...这研一还有考试是真的烦...还得为了奖学金拼一拼哎....

    相关文章

      网友评论

          本文标题:Pytorch预训练模型finetune

          本文链接:https://www.haomeiwen.com/subject/lmbtkqtx.html