美文网首页
BN(Batch Normalization)在TensorFl

BN(Batch Normalization)在TensorFl

作者: EdwardLee | 来源:发表于2017-03-06 19:17 被阅读0次

    BN是Google inception系列模型里,从inception v2到inception v3的一个重要升级,在activation层之前,将卷积层的输出进行归一化,使activation的输入在[0,1]之间,避免梯度消失的问题。

    具体地,BN在TF中实现,涉及到两个方法:tf.nn.moments 和 tf.nn.batch_normalization。

    具体的方法说明请参考官方API文档。主要思路是moments计算数据的mean和variance,batch_normalization利用mean和variance计算归一化后的数据。

    一、tf.nn.moments

    def moments(x, axes, name=None, keep_dims=False)
    

    参数解释:

    ·x 可以理解为我们输出的数据,形如 [batchsize, height, width, kernels]
    ·axes 表示在哪个维度上求解,是个list,例如 [0, 1, 2]
    ·name 就是个名字,不多解释
    ·keep_dims 是否保持维度,不多解释

    这个函数的输出就是BN需要的mean和variance。
    Test code:

    import tensorflow as tf
    sess = tf.InteractiveSession()
    img = tf.random_normal([2, 3])
    axis = list(range(len(img.get_shape()) - 1))
    mean, variance = tf.nn.moments(img, axis)
    mean.eval()
    variance.eval()
    

    输出

    img = [[ 0.69495416  2.08983064 -1.08764684]
           [ 0.31431156 -0.98923939 -0.34656194]]
    mean =  [ 0.50463283  0.55029559 -0.71710438]
    variance =  [ 0.0362222   2.37016821  0.13730171]
    

    可以理解为batchsize=2,kernels=3,最终得到每个kernel对应的mean和variance。

    img=[128,32,32,64]对应的物理意义

    二、tf.nn.batch_normalization

    def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
    

    参数解释:
    ·x同moments方法
    ·mean moments方法的输出之一
    ·variance moments方法的输出之一
    ·offset BN需要学习的参数
    ·scale BN需要学习的参数
    ·variance_epsilon 归一化时防止分母为0加的一个常量

    参数对应的BN计算公式:


    BN计算公式

    其中Xi对应x,μ即为mean,δ对应variance。第3个公式做初步的Norm,第4个公式中,γ即为scale,β对应offset。

    BN在实际中,由于mean和variance是和batch内的数据有关的,因此需要注意训练过程和预测过程中,mean和variance无法使用相同的数据。需要一个trick,即moving_average,代码如下:

    update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY)
    update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
    mean, variance = control_flow_ops.cond(['is_training'], lambda: (mean, variance), lambda: (moving_mean, moving_variance))
    

    在训练的过程中,通过每个step得到的mean和variance,叠加计算对应的moving_average(滑动平均),并最终保存下来以便在inference的过程中使用。
    对于assign_moving_average方法如下:

    def assign_moving_average(variable, value, decay, zero_debias=True, name=None)
    

    其实内部计算比较简单,公式表达如下:
    variable = variable * decay + value * (1 - decay)
    变换一下:
    variable = variable - (1 - decay) * (variable - value)
    减号后面的项就是moving_average的更新delta了。

    相关文章

      网友评论

          本文标题:BN(Batch Normalization)在TensorFl

          本文链接:https://www.haomeiwen.com/subject/fgmugttx.html