美文网首页
pytorch finetune模型

pytorch finetune模型

作者: jiangwenj02 | 来源:发表于2017-11-28 14:42 被阅读0次

    pytorch finetune模型

    文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。
                                                                                           --------作者:jiangwenj02【转载请注明】


    pytorch 模型的存储与读取

    其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

    单独存储模型参数

    存储时使用:

    torch.save(the_model.state_dict(), PATH)
    

    读取时:

    the_model = TheModelClass(*args, **kwargs)
    the_model.load_state_dict(torch.load(PATH))
    

    存储模型与参数

    存储:

    torch.save(the_model, PATH)
    

    读取:

    the_model = torch.load(PATH)
    

    模型的参数

    fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

    pytorch模型参数的形式

    模型的参数是以字典的形式存储的。

    model_dict = the_model.state_dict(),
    for k,v in model_dict.items():
        print(k)
    

    即可看到所有的键值
    如果想修改模型的参数,给相应的键值赋值即可

    model_dict[k] = new_value
    

    最后更新模型的参数

    the_model.load_state_dict(model_dict)
    

    如果模型的key值和在大数据集上训练时的key值是一样的

    我们可以通过下列算法进行读取模型

    model_dict = model.state_dict()
    
    pretrained_dict = torch.load(model_path)
     # 1. filter out unnecessary keys
    diff = {k: v for k, v in model_dict.items() if \
                k in pretrained_dict and pretrained_dict[k].size() == v.size()}
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
    pretrained_dict.update(diff)
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    

    如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

    model_dict = model.state_dict()
    
    pretrained_dict = torch.load(model_path)
    keys = []
    for k,v in pretrained_dict.items():
        keys.append(k)
    i = 0
    for k,v in model_dict.items():
        if v.size() == pretrained_dict[keys[i]].size():
            print(k, ',', keys[i])
             model_dict[k]=pretrained_dict[keys[i]]
        i = i + 1
    model.load_state_dict(model_dict)
    

    如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

    自己找对应关系,一个key对应一个key的赋值

    相关文章

      网友评论

          本文标题:pytorch finetune模型

          本文链接:https://www.haomeiwen.com/subject/rntebxtx.html