Author: 瞎了吗 | Published 2019-06-30 10:04

    CBAM: Convolutional Block Attention Module

    CBAM still adopts a block-based form, but each block is designed more carefully so that the network structure becomes more reasonable and effective.

    Structure

    [Figure: overall structure of CBAM]

    The authors apply a mechanism similar to human attention to re-weight a feature map. The attention mechanism learns, in a trainable way, to assign new weights to features; the features with high weights are the points the network attends to.

    Convolutional Block Attention Module

    As the structure diagram above shows, a feature map passes through a Channel Attention Module and then a Spatial Attention Module, and the output is a refined feature map.
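    Written out as in the paper, with \otimes denoting element-wise multiplication (the attention maps are broadcast across the remaining dimensions):
    \mathbf{F}^{\prime}=\mathbf{M}_{\mathbf{c}}(\mathbf{F}) \otimes \mathbf{F}
    \mathbf{F}^{\prime \prime}=\mathbf{M}_{\mathbf{s}}\left(\mathbf{F}^{\prime}\right) \otimes \mathbf{F}^{\prime}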

    Channel attention

    First, channel attention. An image passed through several convolutional layers yields a feature map whose number of channels equals the number of kernels in the last convolutional layer. That number commonly reaches 1024 or 2048, and not every channel is equally useful for passing information along. Filtering these channels, i.e. attending to them, therefore yields a refined feature.
    The main idea: increase the weights of informative channels and decrease the weights of uninformative ones.
    The formula is:
    \mathbf{M}_{\mathbf{c}}(\mathbf{F})=\sigma(\mathrm{MLP}(\operatorname{AvgPool}(\mathbf{F}))+\mathrm{MLP}(\operatorname{MaxPool}(\mathbf{F})))
    =\sigma\left(\mathbf{W}_{1}\left(\mathbf{W}_{0}\left(\mathbf{F}_{\mathrm{avg}}^{\mathrm{c}}\right)\right)+\mathbf{W}_{1}\left(\mathbf{W}_{0}\left(\mathbf{F}_{\mathrm{max}}^{\mathrm{c}}\right)\right)\right)

    The channel attention structure is as follows:


    [Figure: channel attention module]

    Global pooling over the spatial dimensions produces one descriptor per channel; both the average-pooled and max-pooled descriptors pass through the same MLP, and the two results are added to form the final attention vector (the weights).
    This is very close to SENet, which many papers have confirmed improves results. The difference is that SENet uses only average pooling, while this paper adds max pooling as well; the ablation experiments in the paper show that the extra max pooling improves performance.
    Note that the hidden layer of the MLP is small (a reduction bottleneck), which may help integrate the information.
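    As a minimal sketch of the channel-attention formula above (my own illustration, assuming an input x of shape (N, C, H, W); the full ChannelGate class appears in the implementation further below):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def channel_attention(x, mlp):
        # Global average / max pooling over the spatial dims -> (N, C) descriptors
        avg_desc = F.adaptive_avg_pool2d(x, 1).flatten(1)
        max_desc = F.adaptive_max_pool2d(x, 1).flatten(1)
        # Shared MLP W1(W0(.)) applied to both descriptors, summed, then sigmoid
        att = torch.sigmoid(mlp(avg_desc) + mlp(max_desc))   # (N, C)
        return x * att.unsqueeze(-1).unsqueeze(-1)            # re-weight the channels

    # Example with C = 64 channels and reduction ratio 16
    mlp = nn.Sequential(nn.Linear(64, 4), nn.ReLU(), nn.Linear(4, 64))
    out = channel_attention(torch.randn(2, 64, 32, 32), mlp)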

    Spatial attention

    In the paper, the authors argue that channel attention answers "what" to attend to, while spatial attention answers "where".
    The formula is:
    \mathbf{M}_{\mathbf{s}}(\mathbf{F})=\sigma\left(f^{7 \times 7}([\operatorname{AvgPool}(\mathbf{F}) ; \operatorname{MaxPool}(\mathbf{F})])\right)
    =\sigma\left(f^{7 \times 7}\left(\left[\mathbf{F}_{\mathrm{avg}}^{\mathrm{s}} ; \mathbf{F}_{\mathrm{max}}^{\mathrm{s}}\right]\right)\right)

    The spatial attention structure is shown below:


    [Figure: spatial attention module]

    Avg-pooling and max-pooling are used again, this time across the channel dimension, to summarize the information at each spatial position, and a 7×7 convolution then extracts the spatial attention map.
    Note that in both modules the weights are normalized with a sigmoid.
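    A matching minimal sketch of the spatial-attention formula (again my own illustration; the full SpatialGate class is in the implementation further below):

    import torch
    import torch.nn as nn

    # f^{7x7}: a 7x7 convolution over the 2-channel [max; avg] map
    spatial_conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)

    def spatial_attention(x):
        # Pool across the channel dimension -> two (N, 1, H, W) maps
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        avg_map = torch.mean(x, dim=1, keepdim=True)
        att = torch.sigmoid(spatial_conv(torch.cat([max_map, avg_map], dim=1)))
        return x * att   # (N, 1, H, W) attention map broadcast over channels

    out = spatial_attention(torch.randn(2, 64, 32, 32))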

    The figure below shows an example of applying the paper's method to ResNet:

    [Figure: CBAM integrated into a ResNet block]

    The module is applied to the output of every ResNet block.
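    In code terms the integration looks roughly like the sketch below (my own schematic; the actual Bottleneck class appears in the implementation that follows):

    import torch.nn as nn

    class ResidualBlockWithCBAM(nn.Module):
        # Sketch: CBAM refines the main branch of a residual block before the shortcut is added.
        def __init__(self, convs, cbam, shortcut):
            super().__init__()
            self.convs, self.cbam, self.shortcut = convs, cbam, shortcut
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            out = self.convs(x)                       # standard block convolutions
            out = self.cbam(out) + self.shortcut(x)   # attention-refined branch + shortcut
            return self.relu(out)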

    PyTorch implementation of CBAM

    The code below applies CBAM to ResNeXt; the implementation can also be viewed at the GitHub link.

    # -*-coding:utf-8-*-
    import math
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    __all__ = ['cbam_resnext29_8x64d', 'cbam_resnext29_16x64d']
    
    
    # Convolution + optional BatchNorm + optional ReLU, used by the spatial gate below.
    class BasicConv(nn.Module):
        def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
            super(BasicConv, self).__init__()
            self.out_channels = out_planes
            self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                                  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5,
                                     momentum=0.01, affine=True) if bn else None
            self.relu = nn.ReLU() if relu else None
    
        def forward(self, x):
            x = self.conv(x)
            if self.bn is not None:
                x = self.bn(x)
            if self.relu is not None:
                x = self.relu(x)
            return x
    
    
    # Flattens pooled (N, C, 1, 1) features to (N, C) for the MLP.
    class Flatten(nn.Module):
        def forward(self, x):
            return x.view(x.size(0), -1)
    
    
    # Channel attention M_c: global pooling over the spatial dims + a shared MLP, as in the formula above.
    class ChannelGate(nn.Module):
        def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
            super(ChannelGate, self).__init__()
            self.gate_channels = gate_channels
            self.mlp = nn.Sequential(
                Flatten(),
                nn.Linear(gate_channels, gate_channels // reduction_ratio),
                nn.ReLU(),
                nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
            self.pool_types = pool_types
    
        def forward(self, x):
            channel_att_sum = None
            for pool_type in self.pool_types:
                if pool_type == 'avg':
                    avg_pool = F.avg_pool2d(
                        x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                    channel_att_raw = self.mlp(avg_pool)
                elif pool_type == 'max':
                    max_pool = F.max_pool2d(
                        x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                    channel_att_raw = self.mlp(max_pool)
                elif pool_type == 'lp':
                    lp_pool = F.lp_pool2d(
                        x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                    channel_att_raw = self.mlp(lp_pool)
                elif pool_type == 'lse':
                    # LSE pool only
                    lse_pool = logsumexp_2d(x)
                    channel_att_raw = self.mlp(lse_pool)
    
                if channel_att_sum is None:
                    channel_att_sum = channel_att_raw
                else:
                    channel_att_sum = channel_att_sum + channel_att_raw
    
            scale = torch.sigmoid(channel_att_sum).unsqueeze(
                2).unsqueeze(3).expand_as(x)
            return x * scale
    
    
    # Numerically stable log-sum-exp pooling over the spatial dims (used by the 'lse' pool type).
    def logsumexp_2d(tensor):
        tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
        s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
        outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
        return outputs
    
    
    # Concatenates the per-position max and mean over channels into a 2-channel map for the spatial gate.
    class ChannelPool(nn.Module):
        def forward(self, x):
            return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
    
    
    # Spatial attention M_s: a 7x7 convolution over the [max; avg] channel-pooled map.
    class SpatialGate(nn.Module):
        def __init__(self):
            super(SpatialGate, self).__init__()
            kernel_size = 7
            self.compress = ChannelPool()
            self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(
                kernel_size-1) // 2, relu=False)
    
        def forward(self, x):
            x_compress = self.compress(x)
            x_out = self.spatial(x_compress)
            scale = torch.sigmoid(x_out)  # broadcasting
            return x * scale
    
    
    # CBAM: channel attention followed (optionally) by spatial attention.
    class CBAM(nn.Module):
        def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
            super(CBAM, self).__init__()
            self.ChannelGate = ChannelGate(
                gate_channels, reduction_ratio, pool_types)
            self.no_spatial = no_spatial
            if not no_spatial:
                self.SpatialGate = SpatialGate()
    
        def forward(self, x):
            x_out = self.ChannelGate(x)
            if not self.no_spatial:
                x_out = self.SpatialGate(x_out)
            return x_out
    
    
    # ResNeXt bottleneck whose output is refined by CBAM before the shortcut addition.
    class Bottleneck(nn.Module):
    
        def __init__(self, in_channels, out_channels, stride, cardinality, base_width, expansion):
    
            super(Bottleneck, self).__init__()
            width_ratio = out_channels / (expansion * 64.)
            D = cardinality * int(base_width * width_ratio)
    
            self.relu = nn.ReLU(inplace=True)
            self.cbam_module = CBAM(out_channels)
    
            self.conv_reduce = nn.Conv2d(
                in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
            self.bn_reduce = nn.BatchNorm2d(D)
            self.conv_conv = nn.Conv2d(
                D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
            self.bn = nn.BatchNorm2d(D)
            self.conv_expand = nn.Conv2d(
                D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
            self.bn_expand = nn.BatchNorm2d(out_channels)
    
            self.shortcut = nn.Sequential()
            if in_channels != out_channels:
                self.shortcut.add_module('shortcut_conv',
                                         nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0,
                                                   bias=False))
                self.shortcut.add_module(
                    'shortcut_bn', nn.BatchNorm2d(out_channels))
    
        def forward(self, x):
            out = self.conv_reduce(x)
            out = self.relu(self.bn_reduce(out))
            out = self.conv_conv(out)
            out = self.relu(self.bn(out))
            out = self.conv_expand(out)
            out = self.bn_expand(out)

            residual = self.shortcut(x)

            # CBAM refines the main branch before it is merged with the shortcut
            out = self.cbam_module(out) + residual
            out = self.relu(out)
            return out
    
    
    # ResNeXt-29 backbone for 32x32 inputs, built from the CBAM Bottleneck above.
    class SeResNeXt(nn.Module):
        def __init__(self, cardinality, depth, num_classes, base_width, expansion=4):
            super(SeResNeXt, self).__init__()
            self.cardinality = cardinality
            self.depth = depth
            self.block_depth = (self.depth - 2) // 9
            self.base_width = base_width
            self.expansion = expansion
            self.num_classes = num_classes
            self.output_size = 64
            self.stages = [64, 64 * self.expansion, 128 *
                           self.expansion, 256 * self.expansion]
    
            self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
            self.bn_1 = nn.BatchNorm2d(64)
            self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
            self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
            self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
            self.fc = nn.Linear(self.stages[3], num_classes)
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight.data)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
    
        def block(self, name, in_channels, out_channels, pool_stride=2):
            block = nn.Sequential()
            for bottleneck in range(self.block_depth):
                name_ = '%s_bottleneck_%d' % (name, bottleneck)
                if bottleneck == 0:
                    block.add_module(name_, Bottleneck(in_channels, out_channels, pool_stride, self.cardinality,
                                                       self.base_width, self.expansion))
                else:
                    block.add_module(name_,
                                     Bottleneck(out_channels, out_channels, 1, self.cardinality, self.base_width,
                                                self.expansion))
            return block
    
        def forward(self, x):
            x = self.conv_1_3x3(x)
            x = F.relu(self.bn_1(x), inplace=True)
            x = self.stage_1(x)
            x = self.stage_2(x)
            x = self.stage_3(x)
            x = F.avg_pool2d(x, 8, 1)
            x = x.view(-1, self.stages[3])
            return self.fc(x)
    
    
    def cbam_resnext29_8x64d(num_classes):
        return SeResNeXt(cardinality=8, depth=29, num_classes=num_classes, base_width=64)
    
    
    def cbam_resnext29_16x64d(num_classes):
        return SeResNeXt(cardinality=16, depth=29, num_classes=num_classes, base_width=64)
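
    A quick usage sketch appended to the file above (my own example, not part of the original code): the network is built for 32×32 inputs such as CIFAR, since the stem is a stride-1 3×3 convolution and the final average pooling uses an 8×8 window.

    # Hypothetical smoke test for the model factory defined above
    model = cbam_resnext29_8x64d(num_classes=10)
    dummy = torch.randn(4, 3, 32, 32)   # a batch of four 32x32 RGB images
    logits = model(dummy)
    print(logits.shape)                 # expected: torch.Size([4, 10])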
    
    
