美文网首页
模型参数的裁剪

模型参数的裁剪

作者: 菌子甚毒 | 来源:发表于2023-07-03 18:07 被阅读0次

基础使用

1 如何使用id( )

我们随便定义一个模型:

m_seq = torch.nn.Sequential(
    torch.nn.Linear(2, 2),
    torch.nn.Linear(2, 2),
)

如果只是使用id(m_seq.parameters()) 只会返回整个m_seq.parameters()的一个id. 因此我们使用map().

list(map(id, m_seq.parameters())) # [140575686657984, 140575686318016, 140579143672432, 140579143765264]

可以看见返回了4个id. 它们分别是创建的两个线性层的weight 和 bias 的参数的id(['0.weight', '0.bias', '1.weight', '1.bias']).

2 使用filter根据id滤除参数

当我们明确要滤除参数的模块的时候, 可以使用下面这个方法:

def filter_params(model, to_filter_module):
    ignored_id = list(map(id, to_filter_module.parameters())) # list
    out_params = filter(lambda p: id(p) not in ignored_id, model.parameters())
    return out_params

在这个过程中filter 找到根据对比layer3里面参数的id和model里面所有参数的id(指的是list(map(id, model.parameters())))将layer3的参数滤除.
注意: 该method返回的out_params是一个filter.
使用的例子如下:
例如我们有这样一个有3层的模型:

class MyModel(nn.Module):
    def __init__(self):
        super (MyModel, self).__init__()
        self.layer1 = torch.nn.Linear(2, 2)
        self.layer2 = torch.nn.Linear(2, 2)
        self.layer3 = torch.nn.Linear(2, 2)
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

model = MyModel()

我们希望第一层和第二层用lr1更新, 第三层用lr2更新. 因此我们需要将第三层和第一二层分开, 实现方式如下:

param_exclude_layer3 = filter_params(model, model.layer3) # filter
param_layer3 = model.layer3.parameters() # generator

opt = torch.optim.Adam([
    {'params': param_exclude_layer3, 'lr': 1e-3},
    {'params': param_layer3, 'lr': 1e-4},
])

滤除多层

使用以下代码滤除多层. 将需要滤除的多个模块装在list里面传入.例如: param_exclude_layer3_and_layer1 = filter_params(model, [model.layer3,model.layer1])

def filter_params(model, to_filter_module_list):
    ignored_id = []
    for module in to_filter_module_list:
        ignored_id += list(map(id, module.parameters()))
    print(ignored_id)
    out_params = filter(lambda p: id(p) not in ignored_id, model.parameters())
    return out_params

根据id来滤除的方法因为id的唯一性, 不太可能滤除错误, 但如果模型复杂, 就需要手动索引模块的位置, 例如: model.enc.linears.layer1, 这样比较麻烦, 可能需要测试定位它的位置.

相关文章

网友评论

      本文标题:模型参数的裁剪

      本文链接:https://www.haomeiwen.com/subject/whsnortx.html