最近在做CS231n的Assignment2,需要推导Batch Normalization的反向传播公式并用代码实现。自己试着用链式法则一步步求,最终得出来的式子巨复杂,而且没有求和符号,很不对劲。去看原论文给出的公式也是一头雾水:
参考了
https://www.adityaagrawal.net/blog/deep_learning/bprop_batch_norm
https://kevinzakka.github.io/2016/09/14/batch_normalization/
但里面对求导都是直接用
通过画出计算图不难理解这个式子,但怎么通过链式法则公式本身推导出来呢?自己的推导过程错在哪里呢?研究了两天,终于搞懂了!接下来详细记录一下求导过程。
复习:多元复合函数的求导法则
根据《高等数学》(同济大学第七版)下册第九章第四节,多元函数与多元函数复合的情形有如下定理:
多元复合函数的链式法则用通俗的话说,就是需要对所有跟有关的复合函数施以链式法则并求和。记住这一点,我们撸起袖子开干。
需解决的问题
Batch Normalization向前传播时各式的定义为:
反向传播时,从上游传下来,求。
推导过程
先求比较简单的:
根据式(4),所有算式中都有和,所以需要对每个进行链式法则求导并加和。同理,所有算式中都有和,且是的函数,根据多元复合函数的求导法则,有:
其中, 参考下图便于理解,其中有彩色的关系都需要有求和符号:
最后,将式(3)看成是的复合函数,则有
将式(7)(8)(9)的最终结果全部代入上式即可得:
网友评论