
Deep Residual Shrinkage Networks: Algorithm Fundamentals and a TFLearn Implementation

Author: striving66 | Published 2019-12-30 09:59

    A deep residual shrinkage network is a new neural network architecture. It is essentially an upgraded version of the deep residual network that improves how well deep learning methods learn features from noisy data.

    First, a brief review of deep residual networks, whose basic building block is shown below. Compared with a conventional convolutional neural network, a deep residual network uses identity shortcuts that skip across multiple layers, which eases training and improves accuracy.


    [Figure: basic building block of a deep residual network]
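
    In symbols, a residual block computes y = x + F(x) rather than y = F(x), so the stacked layers only have to learn the residual F(x). A minimal TFLearn sketch of this idea (toy_residual_block is a hypothetical helper, and it assumes the input already has `channels` channels so the identity shortcut needs no projection):

    import tflearn
    from tflearn.layers.conv import conv_2d

    def toy_residual_block(x, channels):
        # F(x): two (batch norm -> ReLU -> 3x3 conv) stages, pre-activation style
        f = tflearn.activation(tflearn.batch_normalization(x), 'relu')
        f = conv_2d(f, channels, 3)
        f = tflearn.activation(tflearn.batch_normalization(f), 'relu')
        f = conv_2d(f, channels, 3)
        # identity shortcut: add the input back unchanged
        return f + x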

    Unlike a deep residual network, a deep residual shrinkage network inserts a small sub-network into each building block. The sub-network learns a set of thresholds that are used to soft-threshold each channel of the feature map. The whole process can be viewed as trainable feature selection: the preceding convolutional layers push important features toward large absolute values and redundant features toward small absolute values; the sub-network learns the boundary between the two; and soft thresholding then sets the redundant features to zero while letting the important features pass through with non-zero outputs.


    [Figure: basic building block of a deep residual shrinkage network]
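
    Soft thresholding itself is the simple element-wise operation y = sign(x) * max(|x| - tau, 0): features whose magnitude falls below the threshold tau are set exactly to zero, and the remaining features are shrunk toward zero by tau. A minimal NumPy sketch, with made-up values purely for illustration:

    import numpy as np

    def soft_threshold(x, tau):
        # zero out entries with |x| <= tau; shrink the rest toward zero by tau
        return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

    x = np.array([-1.5, -0.2, 0.0, 0.3, 2.0])
    print(soft_threshold(x, 0.5))  # [-1.  -0.   0.   0.   1.5]
    print(soft_threshold(x, 0.0))  # tau = 0: the input passes through unchanged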

    The deep residual shrinkage network is in fact a general-purpose method that applies not only to noisy data but also to noise-free data, because its thresholds are determined adaptively from each sample. In other words, if a sample contains no redundant information and needs no soft thresholding, the thresholds can be trained to be very close to zero, and soft thresholding effectively disappears (the tau = 0 case in the sketch above).

    Finally, stacking a number of these basic blocks yields the complete network.


    [Figure: overall architecture of a deep residual shrinkage network]

    Applying the deep residual shrinkage network to MNIST image classification gives good results. Here is the code for the deep residual shrinkage network:

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Thu Dec 26 07:46:00 2019
    
    Implemented using TensorFlow 1.0 and TFLearn 0.3.2
     
    M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 
    IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
    
    @author: me
    """
    
    import tflearn
    import tensorflow as tf
    from tflearn.layers.conv import conv_2d
    
    # Data loading
    from tflearn.datasets import mnist
    X, Y, testX, testY = mnist.load_data(one_hot=True)
    X = X.reshape([-1,28,28,1])
    testX = testX.reshape([-1,28,28,1])
    
    def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                       downsample_strides=2, activation='relu', batch_norm=True,
                       bias=True, weights_init='variance_scaling',
                       bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                       trainable=True, restore=True, reuse=False, scope=None,
                       name="ResidualBlock"):
        
        """Residual shrinkage blocks with channel-wise thresholds."""
    
        residual = incoming
        in_channels = incoming.get_shape().as_list()[-1]
    
        # Variable Scope fix for older TF
        try:
            vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                       reuse=reuse)
        except Exception:
            vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
    
        with vscope as scope:
            name = scope.name
    
            for i in range(nb_blocks):
    
                identity = residual
    
                if not downsample:
                    downsample_strides = 1
    
                if batch_norm:
                    residual = tflearn.batch_normalization(residual)
                residual = tflearn.activation(residual, activation)
                residual = conv_2d(residual, out_channels, 3,
                                 downsample_strides, 'same', 'linear',
                                 bias, weights_init, bias_init,
                                 regularizer, weight_decay, trainable,
                                 restore)
    
                if batch_norm:
                    residual = tflearn.batch_normalization(residual)
                residual = tflearn.activation(residual, activation)
                residual = conv_2d(residual, out_channels, 3, 1, 'same',
                                 'linear', bias, weights_init,
                                 bias_init, regularizer, weight_decay,
                                 trainable, restore)
                
                # Learn channel-wise thresholds and apply soft thresholding
                # (the "shrinkage" step). Global average of |x| over the
                # spatial dimensions gives a [N, 1, 1, C] tensor:
                abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual), axis=2, keep_dims=True), axis=1, keep_dims=True)
                # A small two-layer sub-network outputs one scale per channel
                scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear', regularizer='L2', weight_decay=0.0001, weights_init='variance_scaling')
                scales = tflearn.batch_normalization(scales)
                scales = tflearn.activation(scales, 'relu')
                scales = tflearn.fully_connected(scales, out_channels, activation='linear', regularizer='L2', weight_decay=0.0001, weights_init='variance_scaling')
                scales = tf.expand_dims(tf.expand_dims(scales, axis=1), axis=1)
                # Threshold = mean(|x|) * sigmoid(scales), so it stays positive
                # and below the average absolute activation of each channel
                thres = tf.multiply(abs_mean, tflearn.activations.sigmoid(scales))
                # Soft thresholding: sign(x) * max(|x| - threshold, 0)
                residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual) - thres, 0))
                
    
                # Downsampling
                if downsample_strides > 1:
                    identity = tflearn.avg_pool_2d(identity, 1,
                                                   downsample_strides)
    
                # Projection to new dimension
                if in_channels != out_channels:
                    if (out_channels - in_channels) % 2 == 0:
                        ch = (out_channels - in_channels)//2
                        identity = tf.pad(identity,
                                          [[0, 0], [0, 0], [0, 0], [ch, ch]])
                    else:
                        ch = (out_channels - in_channels)//2
                        identity = tf.pad(identity,
                                          [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                    in_channels = out_channels
    
                residual = residual + identity
    
        return residual
    
    
    # Real-time data preprocessing
    img_prep = tflearn.ImagePreprocessing()
    img_prep.add_featurewise_zero_center(per_channel=True)
    
    # Building A Deep Residual Shrinkage Network
    net = tflearn.input_data(shape=[None, 28, 28, 1], data_preprocessing=img_prep)
    net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
    net = residual_shrinkage_block(net, 1,  8, downsample=True)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, 'relu')
    net = tflearn.global_avg_pool(net)
    # Regression
    net = tflearn.fully_connected(net, 10, activation='softmax')
    mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True)
    net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
    # Training
    model = tflearn.DNN(net, checkpoint_path='model_mnist',
                        max_checkpoints=10, tensorboard_verbose=0,
                        clip_gradients=0.)
    
    model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
              show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist')
    
    training_acc = model.evaluate(X, Y)[0]
    validation_acc = model.evaluate(testX, testY)[0]
    print('Training accuracy:', training_acc)
    print('Testing accuracy:', validation_acc)
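
    One detail of the block above worth noting: the sub-network produces one threshold per channel and per sample, so `thres` has shape [N, 1, 1, C] and broadcasts over the two spatial dimensions of the [N, H, W, C] feature map. A quick NumPy sketch of that broadcast, with arbitrary shapes and random values standing in for real activations:

    import numpy as np

    residual = np.random.randn(2, 28, 28, 8)   # [N, H, W, C] feature map
    abs_mean = np.mean(np.abs(residual), axis=(1, 2), keepdims=True)  # [N, 1, 1, C]
    scales = np.random.randn(2, 1, 1, 8)       # stand-in for the sub-network output
    thres = abs_mean / (1.0 + np.exp(-scales)) # mean(|x|) * sigmoid(scales)

    shrunk = np.sign(residual) * np.maximum(np.abs(residual) - thres, 0.0)
    print(shrunk.shape)  # (2, 28, 28, 8): thresholds broadcast over H and W

    Because each threshold is mean(|x|) multiplied by a sigmoid output in (0, 1), it is always positive but never exceeds the average absolute activation of its channel, which keeps soft thresholding from zeroing out an entire feature map.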
    

    Next, for comparison, here is the code for an ordinary deep residual network (ResNet):

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Thu Dec 26 07:46:00 2019
    
    Implemented using TensorFlow 1.0 and TFLearn 0.3.2
    K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, CVPR, 2016.
    
    @author: me
    """
    
    import tflearn
    
    # Data loading
    from tflearn.datasets import mnist
    X, Y, testX, testY = mnist.load_data(one_hot=True)
    X = X.reshape([-1,28,28,1])
    testX = testX.reshape([-1,28,28,1])
    
    # Real-time data preprocessing
    img_prep = tflearn.ImagePreprocessing()
    img_prep.add_featurewise_zero_center(per_channel=True)
    
    # Building a deep residual network
    net = tflearn.input_data(shape=[None, 28, 28, 1], data_preprocessing=img_prep)
    net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
    net = tflearn.residual_block(net, 1,  8, downsample=True)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, 'relu')
    net = tflearn.global_avg_pool(net)
    # Regression
    net = tflearn.fully_connected(net, 10, activation='softmax')
    mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True)
    net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
    # Training
    model = tflearn.DNN(net, checkpoint_path='model_mnist',
                        max_checkpoints=10, tensorboard_verbose=0,
                        clip_gradients=0.)
    
    model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
              show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist')
    
    training_acc = model.evaluate(X, Y)[0]
    validation_acc = model.evaluate(testX, testY)[0]
    print('Training accuracy:', training_acc)
    print('Testing accuracy:', validation_acc)
    

    Both programs build small networks with only one basic block, and no noise was added to the MNIST data, so the results vary a little from run to run. The accuracies are listed below; even on noise-free data, the deep residual shrinkage network performs well:


    [Table: experimental results]

    Reposted from:

    https://my.oschina.net/u/4223274/blog/3148949

    References:

    M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898, https://ieeexplore.ieee.org/document/8850096
