美文网首页
滑动平均模型

滑动平均模型

作者: 小眼大神 | 来源:发表于2019-03-13 00:20 被阅读0次

    在学习神经网络中看到了一个平均滑动模型,该方法可以使模型在测试数据上表现的更加健壮,而TensorFlow中提供的实现方法为tf.train.ExponentialMovingAverage,起初并不理解该方法为啥能使模型在测试数据上更健壮,多放查找资料之后,记录在此。

    思想

    在初始化ExponentialMovingAverage时,需要提供一个衰减率(decay)来空值模型跟新的书读。ExponentialMovingAverage会对TensorFlow中每一个变量会维护一个影子变量(shadow_variable),影子变量的初始值为变量的初始值,每次迭代时,变量进行更新之后,影子变量的值也会同步更新:

    shadow\_variable = decay * shadow\_variable + (1-decay)*variable

    从上式中可以看到,decay决定模型更新的速度,decay越大,模型越稳定。在实际应用中,decay一般是接近1的数(0.99,0.999等)。

    当decay设置较大时,模型训练比较慢,为了使模型在前期能够更新更快,ExponentialMovingAverage还提供了num_updates参数来动态设置decay大小。而此时的衰减率为:

    decay = min\left\{decay, \frac{1+num\_updates}{10+num\_updates} \right\}

    在使用梯度下降算法进行模型训练时,每次更新参数权重时,该权重的影子变量也会随着模型的训练而更新,最终稳定在一个接近真实权重值的附近。在测试集上使用影子变量替换原来的变量进行预测时,可以得到一个更好的结果。

    即,滑动平均的使用步骤为:

    1. 训练阶段:为每个可训练的权重维护影子变量,并随着迭代的进行更新;
    2. 预测阶段:使用影子变量替代真实变量值,进行预测。

    滑动平均为什么在测试过程中被使用

    训练中一直使用原来不带滑动的参数,可以得到新的参数,如此就可以更新该参数的影子变量shadow_variable。基于上面的式子可以看到,shadow_variable的更新比较平滑,对于随机梯度下降算法而言,更平滑的更新效果较好。

    代码示例

    import tensorflow as tf
    
    v1 = tf.Variable(0, dtype=tf.float32)
    step = tf.Variable(0, trainable=False)
    
    ema = tf.train.ExponentialMovingAverage(0.99, step)
    
    # 定义一个平滑平均偏亮的操作,每次执行时,会更新列表中的变量
    maintain_averages_op = ema.apply([v1])
    
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
    
        # 通过ema.average(v1)获取滑动平均之后变量的取值
        print(sess.run([v1, ema.average(v1)]))
    
        # 更新变量v1的值到5
        sess.run(tf.assign(v1, 5))
    
        # 更新v1的滑动平均值。decay = min{0.99, (1+step)/(10+step)} = 0.1
        # v1 的滑动平均更新为 0.1 * 0 + 0.9 * 5 = 4.5
        sess.run(maintain_averages_op)
        print(sess.run([v1, ema.average(v1)]))
    
        # 更新step为10000
        sess.run(tf.assign(step, 10000))
        # 更新v1的值为10
        sess.run(tf.assign(v1, 10))
        # 更新v1的滑动平均值,decay = min{0.99, (1+step)/(10+step)} = 0.99
        # v1的滑动平均更新为 0.99 * 4.5 + 0.01 * 10 = 4.555
        sess.run(maintain_averages_op)
        print(sess.run([v1, ema.average(v1)]))
    
        # 再次更新滑动平均值 0.99 * 4.555 + 0.001 * 10 = 4.60945
        sess.run(maintain_averages_op)
        print(sess.run([v1, ema.average(v1)]))
    

    相关文章

      网友评论

          本文标题:滑动平均模型

          本文链接:https://www.haomeiwen.com/subject/vrlapqtx.html