
2.4 Reflections on ShuffleNet V1

Author: 深度学习模型优化 | Published 2019-04-26 01:41

1 The Idea Behind ShuffleNet

ShuffleNet improves on ResNet with group convolution and channel shuffle, and can be viewed as a compressed version of ResNet. It is built from two operations:

    • Group convolution
    • Channel shuffle
Figure 1: The ShuffleNet micro-architecture

The essence of ShuffleNet is to restrict each convolution to run within its own group, which significantly reduces the model's computational cost. However, this also confines information flow to the individual groups, with no exchange between them, which weakens the model's representational power. A mechanism for exchanging information across groups is therefore needed: the channel shuffle operation. Because channel shuffle is differentiable, the whole network can still be trained end-to-end.
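
    To make the computational saving concrete, here is a rough multiply count for a 1x1 convolution, dense versus grouped (a minimal sketch; the channel counts and feature-map size below are illustrative, not taken from the paper):

    c_in, c_out, h, w, g = 240, 240, 28, 28, 3
    dense = c_in * c_out * h * w                      # standard 1x1 conv
    grouped = g * (c_in // g) * (c_out // g) * h * w  # g independent small convs
    print(dense / grouped)                            # 3.0: cost drops by a factor of g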

2 Core Code

Shuffling channels across groups:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def shuffle_channels(x, groups):
        """Shuffle the channels of a 4-D tensor across groups."""
        batch_size, channels, height, width = x.size()
        assert channels % groups == 0
        channels_per_group = channels // groups
        # split the channel axis into (groups, channels_per_group)
        x = x.view(batch_size, groups, channels_per_group,
                   height, width)
        # swap the group and per-group axes to interleave the groups
        x = x.transpose(1, 2).contiguous()
        # flatten back to the original shape
        x = x.view(batch_size, channels, height, width)
        return x
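
    A quick sanity check of the permutation (a minimal sketch): with 6 channels in 2 groups, the order [0, 1, 2 | 3, 4, 5] should become [0, 3, 1, 4, 2, 5], interleaving the two groups.

    x = torch.arange(6.).view(1, 6, 1, 1)  # channel i holds the value i
    y = shuffle_channels(x, groups=2)
    print(y.view(-1).tolist())             # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]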
    

ShuffleNet's unit A (stride 1):

    class ShuffleNetUnitA(nn.Module):
        """ShuffleNet unit for stride=1 (identity shortcut, residual add)"""
        def __init__(self, in_channels, out_channels, groups=3):
            super(ShuffleNetUnitA, self).__init__()
            assert in_channels == out_channels
            assert out_channels % 4 == 0
            bottleneck_channels = out_channels // 4
            self.groups = groups
            # 1x1 group conv reduces channels to the bottleneck width
            self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
                                         1, groups=groups, stride=1)
            self.bn2 = nn.BatchNorm2d(bottleneck_channels)
            # 3x3 depthwise conv (one filter per channel)
            self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
                                             bottleneck_channels,
                                             3, padding=1, stride=1,
                                             groups=bottleneck_channels)
            self.bn4 = nn.BatchNorm2d(bottleneck_channels)
            # 1x1 group conv restores the output width
            self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
                                         1, stride=1, groups=groups)
            self.bn6 = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            out = self.group_conv1(x)
            out = F.relu(self.bn2(out))
            # shuffle so the next conv sees information from every group
            out = shuffle_channels(out, groups=self.groups)
            out = self.depthwise_conv3(out)
            out = self.bn4(out)
            out = self.group_conv5(out)
            out = self.bn6(out)
            # identity shortcut: add the input, then ReLU
            out = F.relu(x + out)
            return out
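
    A quick shape check for unit A (the 240-channel, 28x28 input is illustrative; 240 is divisible by 4 and by groups=3, as the asserts require):

    unit_a = ShuffleNetUnitA(240, 240, groups=3)
    x = torch.randn(1, 240, 28, 28)
    print(unit_a(x).shape)  # torch.Size([1, 240, 28, 28]) -- stride 1 keeps the shape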
    

ShuffleNet's unit B (stride 2):

    class ShuffleNetUnitB(nn.Module):
        """ShuffleNet unit for stride=2 (downsampling, concat shortcut)"""
        def __init__(self, in_channels, out_channels, groups=3):
            super(ShuffleNetUnitB, self).__init__()
            # the pooled input is concatenated back at the end, so the
            # residual branch only needs to produce the remaining channels
            out_channels -= in_channels
            assert out_channels % 4 == 0
            bottleneck_channels = out_channels // 4
            self.groups = groups
            self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
                                         1, groups=groups, stride=1)
            self.bn2 = nn.BatchNorm2d(bottleneck_channels)
            # 3x3 depthwise conv with stride 2 halves the spatial size
            self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
                                             bottleneck_channels,
                                             3, padding=1, stride=2,
                                             groups=bottleneck_channels)
            self.bn4 = nn.BatchNorm2d(bottleneck_channels)
            self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
                                         1, stride=1, groups=groups)
            self.bn6 = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            out = self.group_conv1(x)
            out = F.relu(self.bn2(out))
            out = shuffle_channels(out, groups=self.groups)
            out = self.depthwise_conv3(out)
            out = self.bn4(out)
            out = self.group_conv5(out)
            out = self.bn6(out)
            # shortcut: 3x3 average pooling with stride 2, then
            # channel-wise concatenation instead of addition
            x = F.avg_pool2d(x, 3, stride=2, padding=1)
            out = F.relu(torch.cat([x, out], dim=1))
            return out
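
    The same check for unit B, which downsamples (sizes again illustrative): the residual branch produces 480 - 240 = 240 channels at half resolution, and concatenating the pooled 240-channel input yields the full 480 channels.

    unit_b = ShuffleNetUnitB(240, 480, groups=3)
    x = torch.randn(1, 240, 28, 28)
    print(unit_b(x).shape)  # torch.Size([1, 480, 14, 14]) -- channels doubled, resolution halved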
    

3 Pros and Cons

    Cons:

    • Implementing channel shuffle takes a lot of pointer chasing and memory movement, which is inherently time-consuming; it is also highly sensitive to implementation details, so real-world speed tends to fall short of what the FLOP count suggests (the sketch after this list shows the extra copy).
    • The channel shuffle rule is hand-designed rather than learned by the network itself. This goes against the basic principle of letting a network learn its own features through backpropagation, and slides back toward hand-crafted features (as with SIFT/HOG).
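
    The memory cost in the first point is visible in shuffle_channels itself: transpose only rewrites strides, so the following .contiguous() must materialize a full copy of the tensor (a minimal demonstration):

    x = torch.randn(2, 6, 4, 4)
    t = x.view(2, 2, 3, 4, 4).transpose(1, 2)
    print(t.is_contiguous())             # False: transpose just reorders strides
    c = t.contiguous()                   # allocates new storage and copies every element
    print(c.data_ptr() == x.data_ptr())  # False: the shuffle paid for a full copy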

4 Summary

    This article first introduced the basic idea behind ShuffleNet, then walked through its core code, and finally analyzed the weaknesses of ShuffleNet V1, pointing out directions for improvement.
    One last aside: what you can imagine you may not be able to do, what you can do you may not do well, and what you do well may not pay off.
    Remind yourself: if you think of it, do it; if you do it, do it as well as you can; and if you do it well, find a way to make it pay.
