美文网首页
tensorflow 中batch normalize 的使用

tensorflow 中batch normalize 的使用

作者: LanWong | 来源:发表于2019-08-27 15:31 被阅读0次

    最近在学习slim,slim有个很好的地方就是:搭建网络方便,也有很多预训练模型下载。

    但是最近在调slim中的resnet的时候,发现训练集有很高的accuracy(如90%),但是测试集的accuracy还是很低(如0%, 1%),这肯定不是由于欠拟合或者过拟合导致的。

    因为batch_norm 在test的时候,用的是固定的mean和var, 而这个固定的mean和var是通过训练过程中对mean和var进行移动平均得到的。而直接使用train_op会使得模型没有计算mean和var,因此正确的方式是:

    每次训练时应当更新一下moving_mean和moving_var

    optimizer = tf.train.MomentumOptimizer(lr,momentum=FLAGS.momentum,

                                          name='MOMENTUM')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies([tf.group(*update_ops)]):

        # train_op = slim.learning.create_train_op(total_loss, optimizer, global_step)

        train_op = optimizer.minimize(total_loss, global_step=global_step)

    这样在测试的时候即使将is_training改成False也能得到正常的test accuracy了。

    相关文章

      网友评论

          本文标题:tensorflow 中batch normalize 的使用

          本文链接:https://www.haomeiwen.com/subject/ncpeectx.html