RetNet

作者: 放开那个BUG | 来源:发表于2018-08-14 11:45 被阅读25次

       深层网络一般很难训练,要么梯度弥散要么梯度爆炸。但是我们考虑这样一个事实:假设我们有一个浅层网络,网络是很同意训练的。现在我们不断增加新层来建立深层网络,一个极端的情况是增加的层什么也不学习,仅仅只是拷贝浅层的输出,即这样的新层是恒等映射(Identity mapping)。这样网络至少可以和浅层一样,不应该出现退化的问题。残差学习的思路基于这个。
       对于一个堆叠层结构,当输入为x时候,学习到的特征记为H(x),现在我们希望可以学习到残差F(x) = H(x) - x,这样其实原始的学习特征为F(x) + x。之所以这样是因为残差学习相比原始特征直接学习更容易。当残差为0时,此时堆叠层仅仅做了恒等映射,至少网络性能不会下降。事实上,残差不可能为0,这样堆叠层就会在输入特征基础上学习到新的特征。残差单元如图所示。


       为什么残差学习相对更容易,从直观上来看残差学习需要学习的内容少,因为残差一般比较小,学习难度小点。从数学角度来说,首先残差单元可以表示成:
    y_l = h(x_l) + F(x_l,W_l) x_{l+1} = f(y_1)

       其中,x_lx_{l+1}分别表示第l个残差单元的输入和输出(ps:每个残差单元一般包含多层结构)。F是残差函数,表示学习到的残差,而h(x_l) = x_l表示恒等映射,f是relu激活函数。基于上式,我们求得从浅层l到深层L的学习特征为:
       x_L = x_l + \sum\limits_{i=l}^{L-1}F(x_i,W_i)
       链式求导可知:
    \frac{\partial Loss}{\partial x_l} = \frac{\partial Loss}{\partial x_L} . \frac{\partial x_L}{\partial x_l} = \frac{\partial Loss}{\partial x_L} . (1 + \frac{\partial \sum\limits_{i=l}^{L-1}F(x_i,W_i)}{\partial x_l})

    公式中的第一个因子\frac{\partial Loss}{\partial x_L}表示损失函数到达L的梯度,小括号中的1表示短路机可以无损的传播梯度,而另外一项残差梯度则需要经过待遇weights的层,梯度不是直接传过来的。残差梯度不会那么巧全为-1,而且就算其比较小,有1的存在也不会导致梯度消失。所以残差学习会更容易。
       吴恩达说a^{l+2} = g(z^{[l+2]} + a^{[l]}) = g(w^{[l+2]} \times a^{[l+1]} + b^{[l+2]}+ a^{[l]}),激活函数用的relu,如果z^{[l+2]}为负数,那么前面这块为0,只剩下a^{[l]},只是把a^{[l]}赋值给a^{[l+2]},首先对网络性能没有什么影响,然后a^{[l]}a^{[l+2]}这段的网络还可能学到一些特征。如果a^{[l]}a^{[l+2]}的维度不相同,比如,a^{[l+2]}为256维,a^{[l]}为128维,那么增加一个w_s,它为[256, 128]维,w_s \times a^{[l]}就可以得到需要输出的维度。
       至于resnet的结构,我就不说了,这个没啥好讲的。
    感谢这位博主的博客[https://zhuanlan.zhihu.com/p/31852747]

    相关文章

      网友评论

          本文标题:RetNet

          本文链接:https://www.haomeiwen.com/subject/jwjmbftx.html