美文网首页pytorch
pytorch学习笔记-weight decay 和 learn

pytorch学习笔记-weight decay 和 learn

作者: 升不上三段的大鱼 | 来源:发表于2021-08-17 15:43 被阅读0次

    1. Weight decay

    Weight decay 是一种正则化方法,大概意思就是在做梯度下降之前,当前模型的 weight 做一定程度的 decay。
    weights_{t+1} = (1-weight\_decay)*weight_t - lr * gradient
    上面这个就相当于是 weights 减去下面公式对权重的梯度:
    \frac{weight\_decay}{2*lr}weight^2 + loss
    整理一下就是L2正则化:
    loss = loss +\frac{ weight\_decay'}{2} * L_2 (weights)

    所以当 weight\_decay' =\frac{weight\_decay}{lr} 的时候,L2正则化和 weight decay 是一样的,因此也会有人说L2正则就是权重衰减。在SGD中的确是这样,但是在 Adam中就不一定了。

    使用 weight decay 可以:

    • 防止过拟合
    • 保持权重在一个较小在的值,避免梯度爆炸。因为在原本的 loss 函数上加上了权重值的 L2 范数,在每次迭代时,模不仅会去优化/最小化 loss,还会使模型权重最小化。让权重值保持尽可能小,有利于控制权重值的变化幅度(如果梯度很大,说明模型本身在变化很大,去过拟合样本),从而避免梯度爆炸。

    在 pytorch 里可以设置 weight decay。torch.optim.Optimizer 里, SGD、ASGD 、Adam、RMSprop 等都有weight_decay参数设置:

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-4)
    

    参考:
    Deep learning basic-weight decay
    关于量化训练的一个小tip: weight-decay

    2. Learning rate decay

    知道梯度下降的,应该都知道学习率的影响,过大过小都会影响到学习的效果。Learning rate decay 的目的是在训练过程中逐渐降低学习率,pytorch 在torch.optim.lr_scheduler 里提供了很多花样。

    Scheduler 的定义在 optimizer之后, 而参数更新应该在一个 epoch 结束之后。

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', verbose=True)
    
    for epoch in range(10):
       for input,label in dataloader:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()
    

    权重衰减(weight decay)与学习率衰减(learning rate decay)

    相关文章

      网友评论

        本文标题:pytorch学习笔记-weight decay 和 learn

        本文链接:https://www.haomeiwen.com/subject/gebbbltx.html