美文网首页人工智能
PyTorch实现经典网络之ResNet

PyTorch实现经典网络之ResNet

作者: HaloZhang | 来源:发表于2020-12-11 22:33 被阅读0次

    简介

    深度残差网络(Deep residual network, ResNet)的提出是CNN图像史上的一件里程碑事件,让我们先看一下ResNet在ILSVRC和COCO数据集上的战绩:

    ResNet取得了5项第一,并又一次刷新了CNN模型在ImageNet上的历史成绩。ResNet的主要创新点在于设计了一种使用了Shortcut Connection的残差结构,使得网络可以设计的很深,有效解决了梯度消失问题并且同时提升了性能。


    网络退化问题

    深度卷积神经网络在图像识别领域取得了一系列重大的突破。深度神经网络以端到端的多层方式集成了低级、中级、高级的特征以及分类器,通过增加网络层数,网络可以进行更加复杂的特征提取。最近的一些证据表明网络深度对模型的性能至关重要,在ImageNet数据集上的表现良好的模型普遍层数都较大。即便是在一些非视觉识别的任务上,深度模型也带来了很大的好处。

    随之而来的一个问题是,网络深度越深性能就一定越好吗?实际上臭名昭著的梯度消失、爆炸问题从一开始就阻碍了模型收敛,这使得深度越深的模型越难训练。虽然目前已有的一些手段比如BatchNorm可以有效缓解这个问题,但是随后出现的网络模型退化(degradation)问题却更加棘手。随着网络深度增加,准确率开始趋于饱和并且快速下降。出乎意料的是,这种下降并不是由于过拟合导致的。作者通过实验证明,在一个适当的模型上单纯添加了更多层就会导致更高的训练错误率,如下图所示:
    作者分别使用了20层和56层的网络结构在CIFAR-10数据集上进行对比实验,可以看到随着网络层数加深,训练错误率和测试错误率反而越高。

    残差结构

    深度网络的退化问题表明不是所有的系统都容易优化。假设我们现在有一个浅层网络,我们再通过以下方式构造一个对应的深层网络。这个深层网络首先复制已经训练好的浅层网络,其次再往上堆叠更多的恒等映射(Identity mapping)层,即这些新增的层什么都不学习。在这种情况下,这个深层网络应该至少和浅层网络性能一样,也不应该出现退化现象。但是实验表明我们目前掌握的方法无法构造出这种对应的深层网络(也有可能是无法在有限时间内找到)。
    为此,论文作者提出了残差学习来解决网络退化问题。对于一个堆积层结构(由几层叠加组成),当输入为x时,传统方式是期望它学到的特征为H(x)。但是对于残差网络而言,它期望这个堆积层学到的特征为F(x),其中F(x) = H(x) - x,即这个堆积层学到的特征F(x)可以看成是在学习实际输出H(x)和输入x之间的残差,所以命名为残差模块。那么原始输出H(x) = F(x)+x。作者认为学习残差特征F(x)会比直接学习原始特征H(x)更容易。在极端情况下,当残差F(x)=0时,此时堆积层仅仅做了恒等映射,即这些堆积的层不会引起网络性能下降。当然实际上残差也不会为0,这也会使得残差结构可以在输入特征的基础上学习到新的特征,从而即加大了网络深度并且学习了更复杂的特征,但同时又不会引起网络性能下降。
    残差网络结构如下:

    残差结构
    其中右边的曲线就是代表的恒等映射,它跳过了2个层,直接从输入连接到了输出,有点类似电路中的短路连接(shortcut connection)。这种短路连接既不需要额外的参数,也不会增加计算复杂度。整个网络仍然可以使用SGD算法搭配反向传播来进行端到端的训练。

    这里简单分析一下为什么残差学习相对容易,从直观上看,让网络直接学习x → H(x)-x的映射,会比让网络直接学习x→H(x)的映射所学的内容少。因为残差一般比较小,学习难度小一点。下面从数学的角度来分析这个问题,残差模块可以表示为:
    y_l = h(x_l) + F(x_l, W_l)\\ x_{l+1} = f(y_l)
    其中x_lx_{l+1}表示第l个残差单元的输入和输出,注意每一个残差单元一般包含多层结构。F是残差函数,表示残差网络学习到的残差。h函数代表的是恒等映射,即上图中的曲线部分,那么有h(x_l) = x_lf是ReLU激活函数。基于上式,我们求得网络从浅层l到深层L学习到的特征为:
    x_L = x_l + \sum_{i=l}^{L-1}F(x_i,W_i)
    利用链式法则,可以求得反向过程的梯度:
    \frac{\partial l o s s}{\partial x_{l}}=\frac{\partial l o s s}{\partial x_{L}} \cdot \frac{\partial x_{L}}{\partial x_{l}}=\frac{\partial l o s s}{\partial x_{L}} \cdot\left(1+\frac{\partial}{\partial x_{l}} \sum_{i=l}^{L-1} F\left(x_{i}, W_{i}\right)\right) \
    其中注意看小括号中的部分,其中的1表明短路机制可以无损地传播梯度,而另外一项残差则需要继续经过链式法则求导获得残差梯度再传播。而残差梯度也不会那么巧刚好为-1,这就意味着总体梯度不太可能每次都为0,因此使得网络变得更加容易学习。
    完整的内容可以参考论文《Identity Mappings in Deep Residual Networks》


    网络结构

    ResNet网络结构主要参考了VGG19网络,在其基础上通过短路连接加上了残差单元。ResNet大多使用3x3的卷积核并且遵循以下两条设计原则:

    1. 对于同样的输出feature map大小,每层拥有同样数量的filters。
    2. 当feature map的大小降低一半时,feature map的数量增加一倍,以保持网络的复杂度。
    ResNet34结构如下: ResNet网络模型图

    上图中最左边是VGG-19网络,中间是朴素ResNet-34网络,右边是包含残差单元的ResNet-34网络。其中ResNet相比普通网络在每两层之间添加了短路机制,这就形成了残差学习。虚线表示的是feature map的数量发生了变化。

    下面是不同深度的ResNet网络的架构参数描述表: ResNet架构参数表
    其中以ResNet34为例,红色部分代表的是不同残差层的残差单元的数量。

    残差单元

    上图中进行的是两层间的残差学习,当网络更深的时候,可以进行3层之间的残差学习。下面是不同的残差单元示意图: 2种不同的残差单元

    网络结构剖析

    接下来以ResNet-34为例,一层一层地分析它的结构,首先从另外一个角度来看一下ResNet-34。 ResNet-34

    我们的输入图像是224x224,首先通过1个卷积层,接着通过4个残差层,最后通过Softmax之中输出一个1000维的向量,代表ImageNet的1000个分类。

    1.卷积层1

    ResNet的第一步是将图像通过一个名为Conv1的块,这个块包含卷积操作、批量归一化、最大池化操作。

    首先是卷积操作,在ResNet架构参数表中可以看到Conv1块的卷积核的大小是7x7,并且要注意到,这里进行卷积操作的时候设置padding大小为3,stride为2,故最后输出的图像大小为112。又特征图数量为64,故最后输出包含了64个通道,最终大小为112x112x64。下图展示了完整的计算过程(为了简化,这里省略掉了批量归一化操作,其实它也并不改变输出的大小。): Conv1操作
    最大池化操作的时候设置padding大小为2,步长为2,池化块大小为3,因此得到最后输出大小为56。完整计算过程见下图: 最大池化操作

    2.残差层

    我们先来解释一个名词,块。ResNet的每一层都包含若干个块。这是因为ResNet网络深度的加大是通过增加一个块中的操作来实现的,而总体的层数仍然保持不变。这里所说的一个块中的操作通常指的是对输入进行卷积操作、批量归一化操作以及通过ReLU激活函数,当然除了最后一个块,因为它不包含ReLU激活函数。

    块操作

    我们先来描述一下一个块中的操作是怎样的?见下图: 第一个残差单元中的第一个操作

    经过Conv1层之后,我们的输入变为了56x56,接着通过查看ResNet架构参数表中可得,使用的是[3x3,64]的卷积核,输出大小是56x56。我们需要注意的是,在一个block中进行的操作是不会改变输入大小的。这是因为我们设置padding为1,并且步长也设置为1。所以得到的输出大小与输入一致。

    接下来我们展示一下一个包含2层的完整残差单元的计算示意图如下(卷积核为2x[3x3,64]): 第一层中的第一个残差单元
    上图的左半部分代表的是实际计算过程,右图对应的是ResNet模型框架图中的部分。
    同理,3个残差单元堆叠起来之后的计算示意图如下(卷积核为3x[3x3,64]):
    ResNet网络结构图中的其他层也类似,只要知道其中一层的残差单元计算方式,我们很容易就可以推广到整个网络结构中去。 如果我们仔细观察每一层的第一个操作,我们会发现第一个操作使用的stride设置为2,而其余操作的stride设置为1。这意味着网络是通过增大步长来进行下采样的,而不是像传统CNN网络那样通过池化操作来进行。实际上,只有Conv1层中使用了一个最大池化操作,以及在ResNet末尾的全连接层之前执行了一个平均池化操作。 降维操作

    上图的红色部分代表的是第三和第四层中的第一个残差单元,蓝色部分代表的残差单元中的第一个块操作,可以看到stride设置为2,而其余均为默认值1。

    再看一下上图,模型架构中的虚线代表的是要改变输入的维度,对于短路连接,当输入和输出维度一致时,可以直接将输入加到输出上。但是当维度不一致时,这就不能直接相加。注意看ResNet网络模型图,每个不同颜色代表的不同的层,不同层之间的输入和输出大小是不一样的,因此不能直接相加,实际上每个不同层所做的第一个操作就是降低维度。关于降低维度主要有两种策略:

    1. 采用zero-padding增加维度,此时一般要先做一个downsamp,可以采用strde=2的pooling,这样不会增加参数。
    2. 采用新的映射(Projection Shortcut),一般采用1x1的卷积,这样会增加参数,也会增加计算量。

    下面展示一下Projection Shortcut方式的计算过程。以下图为例,输入为56x56x64,输出为28x28x128,选择3x3大小的卷积核,通过设置stride为2,padding为1,得到输出大小为28x28。


    padding方式

    接着采用1x1的卷积,stride设置为1,padding设置为0,得到的输出大小为28x28。


    Projection Shortcut 经过上述2个操作之后,每层中的第一个残差单元的整体计算流程如下,此时残差输出和Projetion Shortcut的输出大小是一致的,可以直接相加。 第二层的第一个残差单元

    下面这张示意图展示了ResNet第二层的整体计算过程。


    第二层计算示意图
    接下来的3、4层计算流程也是一样的,就不再赘述。

    实验结果

    下图是ResNet与其他模型在ImageNet数据集上的结果对比,可以看到ResNet-152在Top-1和Top-5的错误率上均达到了SOTA,再仔细观察下ResNet网络自身之间的对比,也可以发现随着层数的增加,错误率持续降低,可见ResNet有效地解决了层数增加带来的副作用。 ResNet与其他网络结果对比

    代码实践

    网络模型定义相关代码,主要定义了BasicBlock类,即包含2个卷积块的残差单元;Bottleneck类,即包含了3个卷积块的残差单元;以及ResNet类,定义了整个网络结构。完整代码如下:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class BasicBlock(nn.Module):
        # 2层的残差单元
        expansion = 1
        def __init__(self, in_planes, planes, stride=1):
            super(BasicBlock, self).__init__()
    
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            # 第二个卷积操作不改变维度和输出大小,因为stride=1 padding=1
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(planes)
    
            self.shortcut = nn.Sequential()
            # 如果步长不为1,或者输入与输出通道不一致,则需要进行Projection Shortcut操作
            if stride != 1 or in_planes != self.expansion*planes:
                # Projection Shortcut
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion*planes)
                )
    
        def forward(self, x):
            # 依次通过两个卷积层,和shortcut连接层,再累加起来。
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out += self.shortcut(x)
            out = F.relu(out)
            return out
    
    class Bottleneck(nn.Module):
        # 3层的残差单元
        expansion = 4
        def __init__(self, in_planes, planes, stride=1):
            super(Bottleneck, self).__init__()
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                                   stride=stride, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(planes)
            self.conv3 = nn.Conv2d(planes, self.expansion *
                                   planes, kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm2d(self.expansion*planes)
    
            self.shortcut = nn.Sequential()
            if stride != 1 or in_planes != self.expansion*planes:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion*planes,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion*planes)
                )
    
        def forward(self, x):
            out = F.relu(self.bn1(self.conv1(x)))
            out = F.relu(self.bn2(self.conv2(out)))
            out = self.bn3(self.conv3(out))
            out += self.shortcut(x)
            out = F.relu(out)
            return out
    
    class ResNet(nn.Module):
        def __init__(self, config):
            super(ResNet, self).__init__()
            self._config = config
            # 默认输入通道为64
            self.in_channels = 64
    
            # 代表ResNet中的Conv1卷积层
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
    
            # 分别代表ResNet中的4层
            self.layer1 = self._make_layer(config['block_type'], 64, config['num_blocks'][0], stride=1)
            self.layer2 = self._make_layer(config['block_type'], 128, config['num_blocks'][1], stride=2)
            self.layer3 = self._make_layer(config['block_type'], 256, config['num_blocks'][2], stride=2)
            self.layer4 = self._make_layer(config['block_type'], 512, config['num_blocks'][3], stride=2)
            self.linear = nn.Linear(512 * config['block_type'].expansion, config['num_classes'])
    
        def _make_layer(self, block, planes, num_blocks, stride):
            strides = [stride] + [1]*(num_blocks-1)
            layers = []
            for stride in strides:
                layers.append(block(self.in_channels, planes, stride))
                self.in_channels = planes * block.expansion
            return nn.Sequential(*layers)
    
        def forward(self, x):
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = F.avg_pool2d(out, 4)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
            return out
    
        def saveModel(self):
            torch.save(self.state_dict(), self._config['model_name'])
    
        def loadModel(self, map_location):
            state_dict = torch.load(self._config['model_name'], map_location=map_location)
            self.load_state_dict(state_dict, strict=False)
    

    配置模型参数定义ResNet-18网络,设置batch size为500,训练轮次20,采用Adam优化算法,学习率设置为0.0001。
    测试相关代码如下:

    import torch
    from ResNet.network import ResNet
    from ResNet.network import BasicBlock
    from ResNet.network import Bottleneck
    from ResNet.trainer import Trainer
    from ResNet.dataloader import LoadCIFAR10
    from ResNet.dataloader import Construct_DataLoader
    from torch.autograd import Variable
    
    resnet_config = \
    {
        'block_type': BasicBlock,
        'num_blocks': [2,2,2,2], #ResNet18
        'num_epoch': 20,
        'batch_size': 500,
        'lr': 1e-3,
        'l2_regularization':1e-4,
        'num_classes': 10,
        'device_id': 0,
        'use_cuda': True,
        'model_name': '../TrainedModels/ResNet18.model'
    }
    
    if __name__ == "__main__":
        ####################################################################################
        # ResNet 模型
        ####################################################################################
        train_dataset, test_dataset = LoadCIFAR10(True)
        # define ResNet model
        resNet = ResNet(resnet_config)
    
        ####################################################################################
        # 模型训练阶段
        ####################################################################################
        # 实例化模型训练器
        trainer = Trainer(model=resNet, config=resnet_config)
        # 训练
        trainer.train(train_dataset)
        # 保存模型
        trainer.save()
    
        ####################################################################################
        # 模型测试阶段
        ####################################################################################
        resNet.eval()
        if resnet_config['use_cuda']:
            resNet.loadModel(map_location=torch.device('cpu'))
            resNet = resNet.cuda()
        else:
            resNet.loadModel(map_location=lambda storage, loc: storage.cuda(resnet_config['device_id']))
    
        correct = 0
        total = 0
        for images, labels in Construct_DataLoader(test_dataset, resnet_config['batch_size']):
            images = Variable(images)
            labels = Variable(labels)
            if resnet_config['use_cuda']:
                images = images.cuda()
                labels = labels.cuda()
    
            y_pred = resNet(images)
            _, predicted = torch.max(y_pred.data, 1)
            total += labels.size(0)
            temp = (predicted == labels.data).sum()
            correct += temp
        print('Accuracy of the model on the test images: %.2f%%' % (100.0 * correct / total))
    

    测试结果

    训练和测试都是在CIFAR-10小型图像数据集上进行,经过20次迭代之后,在训练集上得到97.96%的准确率,在测试集上得到81.41%的准确率。通过参数调整还可以达到更高的准确率。 ResNet-18在CIFAR-10训练集和测试集上的准确率

    完整代码见https://github.com/HeartbreakSurvivor/ClassicNetworks/tree/master/ResNet


    参考

    相关文章

      网友评论

        本文标题:PyTorch实现经典网络之ResNet

        本文链接:https://www.haomeiwen.com/subject/tiqxgktx.html