美文网首页
【pytorch】初步理解 Batch Normalizatio

【pytorch】初步理解 Batch Normalizatio

作者: 阮恒 | 来源:发表于2018-06-26 10:25 被阅读0次

前言:

    其实之前我一直以为BatchNorm就是Mini-batch SGD,可能是因为两者都有batch??直到阅读去噪经典论文DnCNN,作者在文中大夸BatchNorm,说加快了训练收敛速度,我才后知后觉的来看一下,一查吓一跳,原来我的深度学习基础真真真是相当的薄弱啊。

    言归正传,接下来开始总结我看过的BatchNorm相关介绍。接下来的内容均来自我从参考文献中整合的内容,几乎没有原创。感谢之前辛勤劳动的作者们。


1、提出原因:

    我们知道,CNN网络在训练的过程中,前一层的参数变化影响着后面层的变化(因为前面层的输出是后面的输入),而且这种影响会随着网络深度的增加而不断放大。在CNN训练时,绝大多数都采用mini-batch使用随机梯度下降算法进行训练,那么随着输入数据的不断变化,以及网络中参数不断调整,网络的各层输入数据的分布则会不断变化。 

    机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。Internal Covariate Shift 问题就是说,在训练过程中,因为各层参数老在变,所以每个隐层都会面临covariate shift的问题,也就是在训练过程中,隐层的输入分布老是变来变去。

    因为深层神经网络在做非线性变换前的激活输入值(就是那个x=WU+B,U是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致后向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。

来自莫烦教程的一张图,PreAct和Act分别指有无激活函数。含有BN的行(2、4)指加入BatchNorm。

    如上图所示,第一行是没有激活函数和BN的网络激活值,可以看出数值分布往两个极端走。加入激活函数之后(第三行)就限制在了激活函数的上下确界(梯度饱和区)。而加入BatchNorm之后,激活值往中间(激活函数敏感区)移动,从而有更大的梯度。

2、概念:

    总体而言,BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。对于每个隐层神经元,把逐渐向非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说向损失函数最优值迈动的步子大,也就是说收敛地快。

下面看BatchNorm的公式:

BatchNorm算法流程

第一步是正态变换:某个神经元对应的原始的激活x通过减去mini-Batch内m个实例获得的m个激活x求得的均值E(x)并除以求得的方差Var(x)来进行转换。

但是这样有一个问题,变换以后数据是否失去了原本的分布,导致网络表达能力下降?所以为了拟合原本的分布,作者又对每个神经元增加两个调节参数(scale和shift)γ和β,这两个参数是通过训练来学习到的,目的是还原上一层应该学到的数据分布,使得网络表达能力增强。

3、优势:

    BatchNorm不仅仅极大提升了训练速度,收敛过程大大加快,还能增加分类效果,一种解释是这是类似于Dropout的一种防止过拟合的正则化表达方式,所以不用Dropout也能达到相当的效果。另外调参过程也简单多了,对于初始化要求没那么高,而且可以使用大的学习率等。

参考文献:

Batch Normalization导读

[深度学习] Batch Normalization算法介绍

Batch Normalization 批标准化

相关文章

网友评论

      本文标题:【pytorch】初步理解 Batch Normalizatio

      本文链接:https://www.haomeiwen.com/subject/ferpyftx.html