A BatchNormalize Layer Implemented from Scratch

By 追光者876 | Published 2020-03-12 15:48

Personally, I think batch normalization is a very important yet easily overlooked topic; almost every modern neural network uses it. When I tested on the CIFAR-10 dataset, the same network with a BN layer reached roughly 10% higher validation accuracy than the version without one, which also matches Andrew Ng's remark in his course that BN has a slight regularization effect. The layer below computes the per-batch mean and variance during training, maintains exponential moving averages of those statistics for inference, and learns a scale (gamma) and an offset (beta).

    import tensorflow as tf

    class BatchNormalize(tf.keras.layers.Layer):
        def __init__(self, name='BatchNormal', **kwargs):
            super(BatchNormalize, self).__init__(name=name, **kwargs)
            self._epsilon = 0.001  # small constant to keep the division numerically stable
            self._decay = 0.99     # momentum of the moving averages

        def build(self, input_shape):
            # Non-trainable moving statistics, used at inference time
            self._mean = self.add_weight(name='mean', shape=[input_shape[-1]], dtype=tf.float32, initializer=tf.zeros_initializer(), trainable=False)
            self._variance = self.add_weight(name='variance', shape=[input_shape[-1]], dtype=tf.float32, initializer=tf.ones_initializer(), trainable=False)
            # Trainable scale (gamma) and offset (beta)
            self._gamma = self.add_weight(name='gamma', shape=[input_shape[-1]], dtype=tf.float32, initializer=tf.ones_initializer(), trainable=True)
            self._beta = self.add_weight(name='beta', shape=[input_shape[-1]], dtype=tf.float32, initializer=tf.zeros_initializer(), trainable=True)
            # Reduce over batch, height and width for 4-D inputs; over the batch axis only for 2-D inputs
            self._axes = [0, 1, 2]
            if len(input_shape) == 2:
                self._axes = [0]

        def call(self, inputs, training=None):
            if training:
                # Normalize with the statistics of the current mini-batch
                batch_mean, batch_variance = tf.nn.moments(inputs, axes=self._axes, keepdims=False, name='moment')
                # Update the moving averages: new = old * decay + batch * (1 - decay)
                train_mean = self._mean.assign(self._mean * self._decay + batch_mean * (1.0 - self._decay))
                train_variance = self._variance.assign(self._variance * self._decay + batch_variance * (1.0 - self._decay))
                with tf.control_dependencies([train_mean, train_variance]):
                    return tf.nn.batch_normalization(inputs, batch_mean, batch_variance, self._beta, self._gamma, self._epsilon, name='batch_normal')
            else:
                # Normalize with the accumulated moving statistics
                return tf.nn.batch_normalization(inputs, self._mean, self._variance, self._beta, self._gamma, self._epsilon)
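
The layer drops into a Keras model the same way as the built-in tf.keras.layers.BatchNormalization: during model.fit Keras passes training=True, so the batch statistics are used and the moving averages get updated, while evaluate/predict fall back to the stored moving statistics. The snippet below is only a minimal usage sketch; the tiny CNN and its hyperparameters are placeholders, not the CIFAR-10 network from the experiment mentioned above.

    # Minimal usage sketch (illustrative architecture, not the original experiment's network)
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, padding='same', use_bias=False, input_shape=(32, 32, 3)),
        BatchNormalize(name='bn1'),
        tf.keras.layers.ReLU(),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()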
    
