美文网首页深度学习
自己实现的BatchNormalize层

自己实现的BatchNormalize层

作者: 追光者876 | 来源:发表于2020-03-12 15:48 被阅读0次

个人认为BatchNormalize是一个非常重要但是却很容易被忽略的知识点,目前几乎所有的神经网络都会用到。我在用cifar10数据集测试时,发现同样的网络,有bn要比没有bn层的验证集准确率提高10%左右。这也验证了吴恩达老师在课中所讲的bn层会有轻微的正则化效果。

class BatchNormalize(tf.keras.layers.Layer):
  def __init__(self, name='BatchNormal', **kwargs):
    super(BatchNormalize, self).__init__(name=name, **kwargs)
    self._epsilon = 0.001
    self._decay = 0.99
  def build(self, input_shape):
    self._mean = self.add_weight(name='mean', shape=[input_shape[-1]], dtype=tf.float32, initializer=tf.zeros_initializer(), trainable=False)
    self._variance = self.add_weight(name="variance", shape=[input_shape[-1]], dtype=tf.float32, initializer=tf.ones_initializer(), trainable=False)
    self._gamma = self.add_weight(name='gamma', shape=[input_shape[-1]], dtype=tf.float32, initializer=tf.ones_initializer(), trainable=True)
    self._beta = self.add_weight(name="beta", shape=[input_shape[-1]], dtype=tf.float32, initializer=tf.zeros_initializer(), trainable=True)
    self._axes = [0, 1, 2]
    if len(input_shape) == 2:
        self._axes = [0]
  def call(self, inputs, training=None):
      if training:
          batch_mean, batch_variance = tf.nn.moments(inputs, axes=self._axes, keep_dims=False, name='moment')
          train_mean = self._mean.assign(tf.add(tf.multiply(self._mean, self._decay), tf.multiply(batch_mean, tf.math.subtract(1.0, self._decay))))
          train_variance = self._variance.assign(tf.add(tf.multiply(self._variance, self._decay), tf.multiply(batch_variance, tf.math.subtract(1.0, self._decay))))
          with tf.control_dependencies([train_mean, train_variance]):
              return tf.nn.batch_normalization(inputs, batch_mean, batch_variance, self._beta, self._gamma, self._epsilon, name="batch_normal")
      else:
          return tf.nn.batch_normalization(inputs, self._mean, self._variance, self._beta, self._gamma, self._epsilon)

相关文章

  • 自己实现的BatchNormalize层

    个人认为BatchNormalize是一个非常重要但是却很容易被忽略的知识点,目前几乎所有的神经网络都会用到。我在...

  • MVP

    View View层接口 View层实现 Presenter Presenter层接口 Presenter层实现 ...

  • Pytorch的nn.model

    nn.model是所有网络(net)层的父类,我们自己如果要实现层的话,需要继承该类。 比如:我们自己实现一个线性...

  • Dubbo-面试

    工作原理 service 层:provider 和 consumer,留给自己实现的接口 config 层:配置文...

  • 基于SSM的用户信息管理系统的增删查、登录功能及分页的实现

    SSM的增删改功能实现 controller层 DAO层 Servie接口层 Service实现层 SQL语句的实...

  • pooling层的实现

    Pooling层概述 Pooling层是CNN中的重要组成部分,通常用来实现对网络中Feature Map的降维,...

  • iOS网络层设计-Engine 实现

    iOS 网络层设计 iOS网络层设计-Client 实现 iOS网络层设计-Engine 实现 iOS 网络层 E...

  • iOS网络层设计-Client 实现

    iOS 网络层设计 iOS网络层设计-Client 实现 iOS网络层设计-Engine 实现 iOS 网络层 C...

  • Android MVP

    V层 P层 M层 V层接口实现 在Activity/Fragment中实现,并获取Peraenter对象关系流程:...

  • 定位引擎产品框架说明文档

    定位引擎产品由上至下分为四层,编码实现层、场景应用层、通用算法层、数据采集层,以及测试平台 编码实现层 主要工作:...

网友评论

    本文标题:自己实现的BatchNormalize层

    本文链接:https://www.haomeiwen.com/subject/nyjpjhtx.html