batch normalization的坑我真的是踩到要吐了,几个月前就踩了一次,看了网上好多资料,虽然跑通了但是当时没记录下来,结果这次又遇到了。时隔几个月,已经忘得差不多了,结果又花了半天重新踩了一遍,真是惨痛的教训。
1 API
在Stack Overflow[What is right batch normalization function in Tensorflow?]中有网友对TensorFlow中的batch normalization做了总结,如下,他在其中说到tf.layers.batch_normalization应该是我们的默认选择(在知乎和其他网站中,有很多网友自己实现了batch_normalization,但其实是没必要的,直接使用TensorFlow提供的API就好了):
tf.nn.batch_normalization
is a low-level op. The caller is responsible to handlemean
andvariance
tensors themselves.tf.nn.fused_batch_norm
is another low-level op, similar to the previous one. The difference is that it's optimized for 4D input tensors, which is the usual case in convolutional neural networks.tf.nn.batch_normalization
accepts tensors of any rank greater than 1.tf.layers.batch_normalization
is a high-level wrapper over the previous ops. The biggest difference is that it takes care of creating and managing the running mean and variance tensors, and calls a fast fused op when possible. Usually, this should be the default choice for you.tf.contrib.layers.batch_norm
is the early implementation of batch norm, before it's graduated to the core API (i.e.,tf.layers
). The use of it is not recommended because it may be dropped in the future releases.tf.nn.batch_norm_with_global_normalization
is another deprecated op. Currently, delegates the call totf.nn.batch_normalization
, but likely to be dropped in the future.- Finally, there's also Keras layer
keras.layers.BatchNormalization
, which in case of tensorflow backend invokestf.nn.batch_normalization
.
2 训练
使用tf.layers.batch_normalization:
conv1 = self.conv2d(input, 32, 7, 1)
bn = tf.layers.batch_normalization(conv1, training=self.training)
relu = tf.nn.relu(bn)
训练时要注意两个地方:一是training=True;二是添加以下代码:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)
在使用saver时选中moving_mean和moving_variance(训练和测试都要这段):
var_list = [var for var in tf.global_variables() if "moving" in var.name]
var_list += tf.trainable_variables()
saver = tf.train.Saver(var_list=var_list, max_to_keep=20)
3 测试
让training=False,使用placeholder比使用Python变量更加方便:
self.training = tf.placeholder(tf.bool)
测试代码差不多像下面这样:
_, loss_ = sess.run([optimizer, loss], feed_dict={image: train1, training: True})
acc = sess.run(accuracy, feed_dict={image: train1, training:False})
注意:如果你用training=False测试出来的结果和training=True的结果相差很多,说明moving_mean和moving_variance没有保存成功,请仔细检查前面的代码有没有写错。
网友评论