美文网首页深度学习-推荐系统-CV-NLP程序员
pytorch学习9:如何切断梯度的反向传播

pytorch学习9:如何切断梯度的反向传播

作者: bdb87b292706 | 来源:发表于2018-08-02 13:56 被阅读5719次

    这个问题好像目前还没有一篇博客有过系统的介绍。这里我简单说一下吧。

    首先为什么要切断梯度的反向传播?这里要先从模型的fine-tunning说起。

    提到pytorch下的fine-tunning操作。网上的方法也是各式各样。基本上是先读取一个已有的模型,然后读取其中部分层,然后删去不需要的层,增加需要重新训练的层,往往是最后的FC。然后将这两部分设为不同的学习率。这一部分操作网上很多,我就不再赘述了,有时间会写一个这样的专题。整理一下各式各样的方法。

    除去fine tunning很多时候我们还需要进行的一个操作是保持原网络参数不变,只训练部分分支,或少数几层网络;或者是我们一起训练网络,但是并不希望一些分支的梯度对主干网络的梯度产生影响。这时我们需要切断这些分支的反向传播。

    在pytorch中通过拷贝需要切断位置前的tensor实现这个功能。tensor中拷贝的函数有两个,一个是clone(),另外一个是copy_()。我发现很多人会选择使用clone(),包括我的leader也是这么教我的,但是这是错误的这里需要说明的是,clone()相当于完全复制了之前的tensor,他的梯度也会复制,而且在反向传播时,克隆的样本和结果是等价的,可以简单的理解为clone只是给了同一个tensor不同的代号,和‘=’等价。所以如果想要生成一个新的分开的tensor,请使用copy_()。

    不过对于这样的操作,pytorch中有专门的函数——detach()。所以大家这么用就好了:

    x_new = x.detach()
    

    相关文章

      网友评论

        本文标题:pytorch学习9:如何切断梯度的反向传播

        本文链接:https://www.haomeiwen.com/subject/dxgjvftx.html