最近在学习slim,slim有个很好的地方就是:搭建网络方便,也有很多预训练模型下载。
但是最近在调slim中的resnet的时候,发现训练集有很高的accuracy(如90%),但是测试集的accuracy还是很低(如0%, 1%),这肯定不是由于欠拟合或者过拟合导致的。
因为batch_norm 在test的时候,用的是固定的mean和var, 而这个固定的mean和var是通过训练过程中对mean和var进行移动平均得到的。而直接使用train_op会使得模型没有计算mean和var,因此正确的方式是:
每次训练时应当更新一下moving_mean和moving_var
optimizer = tf.train.MomentumOptimizer(lr,momentum=FLAGS.momentum,
name='MOMENTUM')
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies([tf.group(*update_ops)]):
# train_op = slim.learning.create_train_op(total_loss, optimizer, global_step)
train_op = optimizer.minimize(total_loss, global_step=global_step)
这样在测试的时候即使将is_training改成False也能得到正常的test accuracy了。
网友评论