美文网首页
(五)tensorflow 1.x中关于BN层的坑

(五)tensorflow 1.x中关于BN层的坑

作者: 神经网络爱好者 | 来源:发表于2020-05-21 14:31 被阅读0次

    在最近进行模型训练时,遇到了一些BN层的坑,特此记录一下。

    问题描述
    模型训练的时候,训练集上的准确率很高,测试集的表现很差,排除了其他原因后,锁定在了slim的batch_norm的使用上。

    解决方案

    1、设置依赖

    slim.batch_norm源码
    Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
          train_op = optimizer.minimize(loss)
    

    One can set updates_collections=None to force the updates in place, but that can have a speed penalty, especially in distributed settings.

    在训练时,moving_mean 和 moving_variance 默认是添加到tf.GraphKeys.UPDATE_OPS 中的, 因此需要作为一个依赖项,在更新train_op时跟新参数。将 updates_collections参数设置为None,这样会在训练时立即更新,影响速度。

    2、设置decay参数

    Lower decay value (recommend trying decay=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability.

    由于使用BN层的网络,预测的时候要用到估计的总体均值和方差,如果iteration还比较少的时候就急着去检验或者预测的话,可能这时EMA估计得到的总体均值/方差还不accurate和stable, 所以会造成训练和预测悬殊,这种情况就是造成下面这个issue的原因:https://github.com/tensorflow/tensorflow/issues/7469 解决的办法就是:当训练结果远好于预测的时候,那么可以通过减小decay,早点“热身”。
    默认decay=0.999,一般建议使用0.9

    3、模型保存

    当我们使用batch_norm时,slim.batch_norm中的moving_mean和moving_variance不是trainable的, 所以使用saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)无法保存, 应该改为:

    var_list = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars
    saver = tf.train.Saver(var_list=var_list, max_to_keep=3)
    

    相关文章

      网友评论

          本文标题:(五)tensorflow 1.x中关于BN层的坑

          本文链接:https://www.haomeiwen.com/subject/utxnohtx.html