Implementing SENet in MXNet (Partial Code)

Author: 魔法少女玛格姬 | Published 2018-06-16 14:43

Squeeze-and-Excitation Networks (SENet) won the 2017 ImageNet classification challenge.
Paper: https://arxiv.org/abs/1709.01507
This post gives a brief introduction to SENet and provides an MXNet implementation of SE-ResNet, written mainly with the Gluon API.

In SENet, the two key operations are Squeeze and Excitation, illustrated below:


[Figure: Squeeze-and-Excitation block diagram]

Step 1: Squeeze compresses the feature map along the spatial dimensions with global average pooling, producing one descriptor per channel.

Step 2: Excitation generates a weight for each feature channel, ending with a sigmoid so that each weight lies in (0, 1); these weights model the dependencies between channels.

Step 3: Reweight multiplies each channel of the original feature map by its Excitation weight, recalibrating the features along the channel dimension.
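To make the three steps concrete, here is a minimal sketch using raw MXNet ndarray ops on a dummy feature map; the shapes, the reduction ratio r, and the randomly initialized weight matrices are assumptions for illustration only, not part of the implementation shown later.

    from mxnet import nd

    N, C, H, W = 2, 8, 14, 14
    r = 4                                        # reduction ratio (assumed)
    X = nd.random.uniform(shape=(N, C, H, W))    # dummy feature map

    # Step 1: Squeeze -- global average pooling over the spatial dims -> (N, C)
    z = X.mean(axis=(2, 3))

    # Step 2: Excitation -- FC -> ReLU -> FC -> sigmoid, one weight per channel
    W1 = nd.random.uniform(shape=(C, C // r))
    W2 = nd.random.uniform(shape=(C // r, C))
    s = nd.sigmoid(nd.dot(nd.relu(nd.dot(z, W1)), W2))   # (N, C), values in (0, 1)

    # Step 3: Reweight -- broadcast-multiply each channel of X by its weight
    Y = nd.broadcast_mul(X, s.reshape((N, C, 1, 1)))
    print(Y.shape)  # (2, 8, 14, 14)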

The SE module can easily be embedded into any neural network. Below is the structure of SE-ResNet:


[Figure: SE-ResNet architecture]

Now for the code.
This is the original Residual block, included here as a reference:

    from mxnet import nd
    from mxnet.gluon import nn

    class Residual(nn.Block):  # imperative forward with nd, so a plain Block suffices
        def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
            super(Residual, self).__init__(**kwargs)
            # two 3x3 convolutions on the main path
            self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                                   strides=strides)
            self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
            # optional 1x1 convolution to match shapes on the shortcut path
            if use_1x1conv:
                self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                       strides=strides)
            else:
                self.conv3 = None
            self.bn1 = nn.BatchNorm()
            self.bn2 = nn.BatchNorm()

        def forward(self, X):
            Y = nd.relu(self.bn1(self.conv1(X)))
            Y = self.bn2(self.conv2(Y))
            if self.conv3:
                X = self.conv3(X)
            return nd.relu(Y + X)  # residual (shortcut) connection
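
As a quick sanity check (hypothetical, not in the original post), instantiating the block and feeding it a random input confirms the expected output shape:

    blk = Residual(16, use_1x1conv=True, strides=2)
    blk.initialize()
    X = nd.random.uniform(shape=(4, 3, 32, 32))
    print(blk(X).shape)  # (4, 16, 16, 16)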
    

The key part is the SE module itself. To keep Squeeze and Excitation easy to follow, they are written as a standalone block:

    def Attention(num_channels):
        # SE module: Squeeze (global average pooling) followed by
        # Excitation (FC -> ReLU -> FC -> sigmoid), producing one weight per channel
        net = nn.HybridSequential()
        with net.name_scope():
            net.add(
                nn.GlobalAvgPool2D(),       # Squeeze: (N, C, H, W) -> (N, C, 1, 1)
                nn.Dense(num_channels),     # note: the SENet paper uses num_channels // r here
                nn.Activation('relu'),
                nn.Dense(num_channels),
                nn.Activation('sigmoid')    # per-channel weights in (0, 1)
            )
        return net
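
A quick shape check (hypothetical, not in the original post) shows what the module produces: a flat vector with one weight per channel.

    att = Attention(8)
    att.initialize()
    X = nd.random.uniform(shape=(2, 8, 14, 14))
    print(att(X).shape)  # (2, 8) -- Dense flattens the pooled (2, 8, 1, 1) input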
    

Now embed the SE module into the Residual block and apply its output with a broadcast multiply:

    class SEResidual(nn.Block):
        def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
            super(SEResidual, self).__init__(**kwargs)
            self.num_channels = num_channels   # kept so forward() can reshape the weights
            self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                                   strides=strides)
            self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
            if use_1x1conv:
                self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                       strides=strides)
            else:
                self.conv3 = None
            self.bn1 = nn.BatchNorm()
            self.bn2 = nn.BatchNorm()
            self.weight = Attention(num_channels)   # the SE module

        def forward(self, X):
            Y = nd.relu(self.bn1(self.conv1(X)))
            Y = self.bn2(self.conv2(Y))
            W = Y
            for layer in self.weight:  # W ends up as the per-channel attention weights
                W = layer(W)
            if self.conv3:
                X = self.conv3(X)
            # Reweight: reshape the (N, C) weights to (N, C, 1, 1) and scale Y channel-wise
            Y = nd.broadcast_mul(Y, nd.reshape(W, shape=(-1, self.num_channels, 1, 1)))
            return nd.relu(Y + X)
    

Finally, just stack SE-Residual blocks like building bricks to get the full network; a sketch of how that might look is below.
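As an illustration only (the post itself stops at the block definition), here is one way the blocks could be assembled into an SE-ResNet-18-style network; the stage layout, channel counts, and the 10-class output follow the usual ResNet-18 recipe and are assumptions, not the author's code.

    # hypothetical assembly into an SE-ResNet-18-style network (assumed layout)
    def se_resnet_stage(num_channels, num_residuals, first_block=False):
        stage = nn.Sequential()
        for i in range(num_residuals):
            if i == 0 and not first_block:
                stage.add(SEResidual(num_channels, use_1x1conv=True, strides=2))
            else:
                stage.add(SEResidual(num_channels))
        return stage

    net = nn.Sequential()
    net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
            nn.BatchNorm(), nn.Activation('relu'),
            nn.MaxPool2D(pool_size=3, strides=2, padding=1),
            se_resnet_stage(64, 2, first_block=True),
            se_resnet_stage(128, 2),
            se_resnet_stage(256, 2),
            se_resnet_stage(512, 2),
            nn.GlobalAvgPool2D(),
            nn.Dense(10))   # 10 output classes (assumed)

    net.initialize()
    X = nd.random.uniform(shape=(1, 3, 224, 224))
    print(net(X).shape)  # (1, 10)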
That's it~
