这个问题好像目前还没有一篇博客有过系统的介绍。这里我简单说一下吧。
首先为什么要切断梯度的反向传播?这里要先从模型的fine-tunning说起。
提到pytorch下的fine-tunning操作。网上的方法也是各式各样。基本上是先读取一个已有的模型,然后读取其中部分层,然后删去不需要的层,增加需要重新训练的层,往往是最后的FC。然后将这两部分设为不同的学习率。这一部分操作网上很多,我就不再赘述了,有时间会写一个这样的专题。整理一下各式各样的方法。
除去fine tunning很多时候我们还需要进行的一个操作是保持原网络参数不变,只训练部分分支,或少数几层网络;或者是我们一起训练网络,但是并不希望一些分支的梯度对主干网络的梯度产生影响。这时我们需要切断这些分支的反向传播。
在pytorch中通过拷贝需要切断位置前的tensor实现这个功能。tensor中拷贝的函数有两个,一个是clone(),另外一个是copy_()。我发现很多人会选择使用clone(),包括我的leader也是这么教我的,但是这是错误的这里需要说明的是,clone()相当于完全复制了之前的tensor,他的梯度也会复制,而且在反向传播时,克隆的样本和结果是等价的,可以简单的理解为clone只是给了同一个tensor不同的代号,和‘=’等价。所以如果想要生成一个新的分开的tensor,请使用copy_()。
不过对于这样的操作,pytorch中有专门的函数——detach()。所以大家这么用就好了:
x_new = x.detach()
网友评论