美文网首页
解决TensorFlow中Batch Normalization

解决TensorFlow中Batch Normalization

作者: nonoka | 来源:发表于2019-01-23 14:12 被阅读0次

            batch normalization的坑我真的是踩到要吐了,几个月前就踩了一次,看了网上好多资料,虽然跑通了但是当时没记录下来,结果这次又遇到了。时隔几个月,已经忘得差不多了,结果又花了半天重新踩了一遍,真是惨痛的教训。

    1 API

            在Stack Overflow[What is right batch normalization function in Tensorflow?]中有网友对TensorFlow中的batch normalization做了总结,如下,他在其中说到tf.layers.batch_normalization应该是我们的默认选择(在知乎和其他网站中,有很多网友自己实现了batch_normalization,但其实是没必要的,直接使用TensorFlow提供的API就好了):

    • tf.nn.batch_normalization is a low-level op. The caller is responsible to handle mean and variance tensors themselves.
    • tf.nn.fused_batch_norm is another low-level op, similar to the previous one. The difference is that it's optimized for 4D input tensors, which is the usual case in convolutional neural networks. tf.nn.batch_normalization accepts tensors of any rank greater than 1.
    • tf.layers.batch_normalization is a high-level wrapper over the previous ops. The biggest difference is that it takes care of creating and managing the running mean and variance tensors, and calls a fast fused op when possible. Usually, this should be the default choice for you.
    • tf.contrib.layers.batch_norm is the early implementation of batch norm, before it's graduated to the core API (i.e., tf.layers). The use of it is not recommended because it may be dropped in the future releases.
    • tf.nn.batch_norm_with_global_normalization is another deprecated op. Currently, delegates the call to tf.nn.batch_normalization, but likely to be dropped in the future.
    • Finally, there's also Keras layer keras.layers.BatchNormalization, which in case of tensorflow backend invokes tf.nn.batch_normalization.

    2 训练

            使用tf.layers.batch_normalization:

    conv1 = self.conv2d(input, 32, 7, 1)
    bn = tf.layers.batch_normalization(conv1, training=self.training)
    relu = tf.nn.relu(bn)
    

            训练时要注意两个地方:一是training=True;二是添加以下代码:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)
    

            在使用saver时选中moving_meanmoving_variance(训练和测试都要这段):

    var_list = [var for var in tf.global_variables() if "moving" in var.name]
    var_list += tf.trainable_variables()
    saver = tf.train.Saver(var_list=var_list, max_to_keep=20)
    

    3 测试

            让training=False,使用placeholder比使用Python变量更加方便:

    self.training = tf.placeholder(tf.bool)
    

            测试代码差不多像下面这样:

    _, loss_ = sess.run([optimizer, loss], feed_dict={image: train1, training: True})
    
    acc = sess.run(accuracy, feed_dict={image: train1, training:False})
    

            注意:如果你用training=False测试出来的结果和training=True的结果相差很多,说明moving_meanmoving_variance没有保存成功,请仔细检查前面的代码有没有写错。

    相关文章

      网友评论

          本文标题:解决TensorFlow中Batch Normalization

          本文链接:https://www.haomeiwen.com/subject/awbjjqtx.html