TORCH03-06 ResNet Networks

Author: 杨强AT南京 | Published 2020-04-15 08:34

      Because many learning algorithms are highly abstracted and distilled, their engineering implementation and application are usually quite simple. Residual networks prove this point once more. The ResNet18 implemented in this article is not as robust as the official implementation, but we build it in a straightforward, "linear" way that is easier to follow, and the final ResNet18 has exactly the same structure as the official network.
      The defining feature of a residual network is that it trains on residuals.


    • The idea of residuals has been used before in ensemble learning, namely in the Gradient Boosting algorithm (a minimal sketch follows this list).

      • During training, each new learner is fitted to the residuals of the current ensemble; the residuals serve as the training targets. The sum of all these residual learners forms the final, stronger model.
    • Residual vs. error

      1. Residual: the difference between a measured value and a predicted value.
      2. Error: the difference between a measured value and the true value.
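
    • To make the residual-fitting idea concrete, below is a minimal sketch of gradient boosting for squared loss, where each new learner is fitted to the residuals of the current ensemble. It uses scikit-learn's DecisionTreeRegressor purely for illustration (scikit-learn is assumed to be available); the toy data and hyperparameters are made up.

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor

    # Toy 1-D regression data (made up for illustration)
    rng = np.random.RandomState(0)
    X = rng.uniform(0, 6, size=(200, 1))
    y = np.sin(X[:, 0]) + 0.1 * rng.randn(200)

    learners, lr = [], 0.1
    pred = np.zeros_like(y)             # current ensemble prediction
    for _ in range(100):
        residual = y - pred             # residual = measured - predicted
        tree = DecisionTreeRegressor(max_depth=2).fit(X, residual)
        learners.append(tree)
        pred += lr * tree.predict(X)    # the ensemble is the sum of residual learners

    print("final training MSE:", np.mean((y - pred) ** 2))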

    About the ResNet residual network

    • ResNet won the 2015 ImageNet competition. It was proposed by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun of Microsoft Research; by introducing residual blocks it made it possible to train networks of up to 152 layers.
      • Viewed through the lens of residuals, ResNet solves the degradation problem of deep neural networks.

    Network degradation

    • Network degradation:
      • Degradation is caused by redundancy in deep neural networks.
        1. Network redundancy: suppose a 50-layer network would already be enough to reach the optimum, but the network we actually design has 100 layers; the extra 50 layers are redundant (and, because of vanishing gradients, they are hard to train). More importantly, when designing a network we never know in advance how many layers the optimal architecture should have, so redundancy is essentially unavoidable.
        2. Network degradation: degradation is not overfitting. Ideally, even if the later layers of a deep network are redundant, training should be able to turn them into identity layers (output exactly equal to input). In practice the network cannot learn such identity layers, and this failure is what we call degradation.

    The idea behind residual networks

    • In a traditional network, if a layer turns out to be redundant, training has to find weights that make the layer's output equal to its input; writing this mapping as a function h, we need h(x) = x. Learning this directly is difficult.

    • A residual network changes the formulation and designs the layer as h(x) = x + F(x), then trains F(x) toward 0 (a minimal sketch of this idea follows below).

      • Why F(x) = 0 is easier: weights are typically initialized around 0 (Gaussian or uniform), so driving the weights toward 0 is much easier than learning weights that realize h(x) = x.
      • F(x) is called the residual, i.e. the residual term.
    1. Schematic of traditional training (figure: learning W such that h(x) = x)
    2. Schematic of residual training (figure: learning W of the residual branch F)
    • The key question now is how to design the residual term.
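
    • A minimal sketch of this reformulation, using a fully connected F purely for illustration (the convolutional blocks used by ResNet come later): the module computes h(x) = x + F(x), so if the weights of F are driven to 0 the layer automatically becomes the identity.

    import torch
    from torch.nn import Module, Linear, ReLU

    class ResidualMapping(Module):
        """h(x) = x + F(x); driving F's weights to 0 makes h the identity."""
        def __init__(self, dim):
            super(ResidualMapping, self).__init__()
            self.fc_1 = Linear(dim, dim)
            self.fc_2 = Linear(dim, dim)
            self.relu = ReLU()

        def forward(self, x):
            f = self.fc_2(self.relu(self.fc_1(x)))   # the residual F(x)
            return x + f                             # h(x) = x + F(x)

    m = ResidualMapping(8)
    # Zero all weights of F: the block now returns its input unchanged.
    for p in m.parameters():
        torch.nn.init.zeros_(p)
    x = torch.randn(2, 8)
    print(torch.allclose(m(x), x))    # True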

    Designing the residual term (the residual block)

    • Replacing the hard-to-learn identity mapping with a residual term is the core design idea of the residual network.

      • ResNet uses convolutional layers as the residual term, two of them per block:
        • F(x) = W_2 \sigma (W_1 x)
      • At computation time this becomes a skip connection: every two layers the output is
        • y = \sigma(F(x) + x)
    • The skip connection is how the residual network is usually drawn.

    (Figure: classic residual block with a skip connection; image from the web)
    • Residual network designs of different depths
    (Figure: the 18-, 34-, 50-, 101- and 152-layer ResNet configurations)
    • Input image size:

      • 224 \times 224 \times 3 or 224 \times 224 \times 1
    • All ResNet variants share a common head and tail around the residual stages:

      • Before the residual stages (the head):
        • Input size: 224 \times 224 \times 3
        • Convolution: kernel = 7 \times 7, channels = 64, stride = 2
        • Max pooling: kernel = 3 \times 3, stride = 2
        • Output size: 56 \times 56 \times 64
      • After the residual stages (the tail), a fully connected classifier:
        • Global average pooling
        • 1000-d fully connected layer
      • Output function:
        • softmax
    • Below we describe the 18-layer ResNet; the deeper variants follow the same pattern. (A minimal sketch of the shared head and tail follows.)
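
    • A minimal sketch of the shared head and tail (BatchNorm/ReLU placement follows the torchvision implementation; the residual stages in between are omitted here):

    import torch
    from torch.nn import Sequential, Conv2d, BatchNorm2d, ReLU, MaxPool2d, AdaptiveAvgPool2d, Flatten, Linear

    # Head: 224 x 224 x 3 -> 56 x 56 x 64
    head = Sequential(
        Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),   # -> 112 x 112 x 64
        BatchNorm2d(64),
        ReLU(inplace=True),
        MaxPool2d(kernel_size=3, stride=2, padding=1),                    # -> 56 x 56 x 64
    )

    # Tail: 7 x 7 x 512 -> 1000-d logits (softmax is applied by the loss during training)
    tail = Sequential(
        AdaptiveAvgPool2d((1, 1)),
        Flatten(1),
        Linear(512, 1000),
    )

    print(head(torch.randn(1, 3, 224, 224)).shape)    # torch.Size([1, 64, 56, 56])
    print(tail(torch.randn(1, 512, 7, 7)).shape)      # torch.Size([1, 1000])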

    The 18-layer residual network design

    1. Residual stage 1 (2 blocks)

      • Block 1
        • Input: 56 \times 56 \times 64
        • Conv 1: kernel = 3 \times 3, channels = 64, stride = 1
        • Conv 2: kernel = 3 \times 3, channels = 64, stride = 1
        • Output: 56 \times 56 \times 64
      • Block 2
        • Input: 56 \times 56 \times 64
        • Conv 1: kernel = 3 \times 3, channels = 64, stride = 1
        • Conv 2: kernel = 3 \times 3, channels = 64, stride = 1
        • Output: 56 \times 56 \times 64
    2. Residual stage 2 (2 blocks)

      • Block 1
        • Input: 56 \times 56 \times 64
        • Conv 1: kernel = 3 \times 3, channels = 128, stride = 2
        • Intermediate output: 28 \times 28 \times 128
        • Conv 2: kernel = 3 \times 3, channels = 128, stride = 1
        • Output: 28 \times 28 \times 128
      • Block 2
        • Input: 28 \times 28 \times 128
        • Conv 1: kernel = 3 \times 3, channels = 128, stride = 1
        • Conv 2: kernel = 3 \times 3, channels = 128, stride = 1
        • Output: 28 \times 28 \times 128
    3. Residual stage 3 (2 blocks)

      • Block 1
        • Input: 28 \times 28 \times 128
        • Conv 1: kernel = 3 \times 3, channels = 256, stride = 2
        • Intermediate output: 14 \times 14 \times 256
        • Conv 2: kernel = 3 \times 3, channels = 256, stride = 1
        • Output: 14 \times 14 \times 256
      • Block 2
        • Input: 14 \times 14 \times 256
        • Conv 1: kernel = 3 \times 3, channels = 256, stride = 1
        • Conv 2: kernel = 3 \times 3, channels = 256, stride = 1
        • Output: 14 \times 14 \times 256
    4. Residual stage 4 (2 blocks)

      • Block 1
        • Input: 14 \times 14 \times 256
        • Conv 1: kernel = 3 \times 3, channels = 512, stride = 2
        • Intermediate output: 7 \times 7 \times 512
        • Conv 2: kernel = 3 \times 3, channels = 512, stride = 1
        • Output: 7 \times 7 \times 512
      • Block 2
        • Input: 7 \times 7 \times 512
        • Conv 1: kernel = 3 \times 3, channels = 512, stride = 1
        • Conv 2: kernel = 3 \times 3, channels = 512, stride = 1
        • Output: 7 \times 7 \times 512
    5. Downsampling

      • When computing x + F(x), x and F(x) may have different shapes; x has to be downsampled so that it matches F(x) (see the sketch below).
      • ResNet performs this downsampling with a (strided) convolution.
    • Note
      • A fully connected neural network can also be designed as a residual network.
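
    • For reference, a sketch of the projection shortcut. The official torchvision blocks match the identity branch with a 1 \times 1 convolution of stride 2 plus BN (which is why Conv2d-24 in the official summary in the next section has only 8,192 parameters), while the manual implementation in this article uses a 3 \times 3 convolution instead.

    import torch
    from torch.nn import Sequential, Conv2d, BatchNorm2d

    # Identity branch: 64 channels, 56 x 56; residual branch F(x): 128 channels, 28 x 28.
    x = torch.randn(1, 64, 56, 56)

    # 1 x 1 convolution with stride 2: changes 64 -> 128 channels and halves the spatial size.
    projection = Sequential(
        Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
        BatchNorm2d(128),
    )
    print(projection(x).shape)    # torch.Size([1, 128, 28, 28])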

    Implementing the residual network

    The official implementation

    • PyTorch provides an official ResNet implementation in the torchvision.models module.
    from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
    from torchsummary import summary
    net = resnet18()
    # print(net)
    # print("=========================================================")
    # Print the network structure
    print(summary(net, input_size=(3, 224, 224), device='cpu'))
    # print("=========================================================")
    # # Print the network structure (on the GPU)
    # print(summary(net.cuda(), input_size=(3, 224, 224)))
    
    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1         [-1, 64, 112, 112]           9,408
           BatchNorm2d-2         [-1, 64, 112, 112]             128
                  ReLU-3         [-1, 64, 112, 112]               0
             MaxPool2d-4           [-1, 64, 56, 56]               0
                Conv2d-5           [-1, 64, 56, 56]          36,864
           BatchNorm2d-6           [-1, 64, 56, 56]             128
                  ReLU-7           [-1, 64, 56, 56]               0
                Conv2d-8           [-1, 64, 56, 56]          36,864
           BatchNorm2d-9           [-1, 64, 56, 56]             128
                 ReLU-10           [-1, 64, 56, 56]               0
           BasicBlock-11           [-1, 64, 56, 56]               0
               Conv2d-12           [-1, 64, 56, 56]          36,864
          BatchNorm2d-13           [-1, 64, 56, 56]             128
                 ReLU-14           [-1, 64, 56, 56]               0
               Conv2d-15           [-1, 64, 56, 56]          36,864
          BatchNorm2d-16           [-1, 64, 56, 56]             128
                 ReLU-17           [-1, 64, 56, 56]               0
           BasicBlock-18           [-1, 64, 56, 56]               0
               Conv2d-19          [-1, 128, 28, 28]          73,728
          BatchNorm2d-20          [-1, 128, 28, 28]             256
                 ReLU-21          [-1, 128, 28, 28]               0
               Conv2d-22          [-1, 128, 28, 28]         147,456
          BatchNorm2d-23          [-1, 128, 28, 28]             256
               Conv2d-24          [-1, 128, 28, 28]           8,192
          BatchNorm2d-25          [-1, 128, 28, 28]             256
                 ReLU-26          [-1, 128, 28, 28]               0
           BasicBlock-27          [-1, 128, 28, 28]               0
               Conv2d-28          [-1, 128, 28, 28]         147,456
          BatchNorm2d-29          [-1, 128, 28, 28]             256
                 ReLU-30          [-1, 128, 28, 28]               0
               Conv2d-31          [-1, 128, 28, 28]         147,456
          BatchNorm2d-32          [-1, 128, 28, 28]             256
                 ReLU-33          [-1, 128, 28, 28]               0
           BasicBlock-34          [-1, 128, 28, 28]               0
               Conv2d-35          [-1, 256, 14, 14]         294,912
          BatchNorm2d-36          [-1, 256, 14, 14]             512
                 ReLU-37          [-1, 256, 14, 14]               0
               Conv2d-38          [-1, 256, 14, 14]         589,824
          BatchNorm2d-39          [-1, 256, 14, 14]             512
               Conv2d-40          [-1, 256, 14, 14]          32,768
          BatchNorm2d-41          [-1, 256, 14, 14]             512
                 ReLU-42          [-1, 256, 14, 14]               0
           BasicBlock-43          [-1, 256, 14, 14]               0
               Conv2d-44          [-1, 256, 14, 14]         589,824
          BatchNorm2d-45          [-1, 256, 14, 14]             512
                 ReLU-46          [-1, 256, 14, 14]               0
               Conv2d-47          [-1, 256, 14, 14]         589,824
          BatchNorm2d-48          [-1, 256, 14, 14]             512
                 ReLU-49          [-1, 256, 14, 14]               0
           BasicBlock-50          [-1, 256, 14, 14]               0
               Conv2d-51            [-1, 512, 7, 7]       1,179,648
          BatchNorm2d-52            [-1, 512, 7, 7]           1,024
                 ReLU-53            [-1, 512, 7, 7]               0
               Conv2d-54            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-55            [-1, 512, 7, 7]           1,024
               Conv2d-56            [-1, 512, 7, 7]         131,072
          BatchNorm2d-57            [-1, 512, 7, 7]           1,024
                 ReLU-58            [-1, 512, 7, 7]               0
           BasicBlock-59            [-1, 512, 7, 7]               0
               Conv2d-60            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-61            [-1, 512, 7, 7]           1,024
                 ReLU-62            [-1, 512, 7, 7]               0
               Conv2d-63            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-64            [-1, 512, 7, 7]           1,024
                 ReLU-65            [-1, 512, 7, 7]               0
           BasicBlock-66            [-1, 512, 7, 7]               0
    AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
               Linear-68                 [-1, 1000]         513,000
    ================================================================
    Total params: 11,689,512
    Trainable params: 11,689,512
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.57
    Forward/backward pass size (MB): 62.79
    Params size (MB): 44.59
    Estimated Total Size (MB): 107.96
    ----------------------------------------------------------------
    None
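
    • As a quick cross-check of the table above, the total parameter count can also be computed directly from the model object:

    from torchvision.models import resnet18
    net = resnet18()
    # Should print 11689512, matching "Total params" in the summary above.
    print(sum(p.numel() for p in net.parameters()))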
    

    Manual implementation

    • We again use the 18-layer network as the example.

    • Implementation plan:

      • Build residual block -> residual stage -> residual network;
        • each residual stage consists of residual blocks (the first block of a stage performs the downsampling)
        • in general, only the first block of a stage downsamples.
      • The residual blocks are then stacked into the full network.
    • Implementation diagram

    (Figure: implementation details of the 18-layer residual network)
    • The downsampling path reduces the spatial size (and adjusts the channel count):
      • There are many ways to do this; here we use a convolution + BN (the BN can be kept or omitted).

    Implementing the residual block

    from torch.nn import Module, Conv2d, BatchNorm2d, MaxPool2d, ReLU, Sequential, AdaptiveAvgPool2d, Linear
    
    # Residual block
    class ResBlock(Module):
        def __init__(self, input_channels, output_channels, stride=1, downsample=None):
            super(ResBlock, self).__init__()
            # First convolution
            self.conv_1 = Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn_1 = BatchNorm2d(output_channels)
            
            # First activation
            self.relu_1 = ReLU(inplace=True)
            
            # Second convolution
            self.conv_2 = Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn_2 = BatchNorm2d(output_channels)
            
            # Optional downsampling of the identity branch
            self.downsample = downsample
            
            # Second activation
            self.relu_2 = ReLU(inplace=True)
            
        def forward(self, x):
            identity = x     # in x + F(x), x may have to be downsampled to match the shape of F(x)
            
            # First conv-bn-relu
            y = self.conv_1(x)
            y = self.bn_1(y)
            y = self.relu_1(y)
            
            # Second conv-bn
            y = self.conv_2(y)
            y = self.bn_2(y)
            
            # Downsample the identity branch if needed
            if self.downsample is not None:
                identity = self.downsample(x)
            
            y += identity     # residual output: x + F(x)
            y = self.relu_2(y)
            return y
            
    
    • Inspect the block structure - no downsampling
    from torchsummary import summary
    block = ResBlock(64, 64, stride=1)
    print(summary(block,input_size=(64, 56, 56), device='cpu'))
    
    
    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1           [-1, 64, 56, 56]          36,864
           BatchNorm2d-2           [-1, 64, 56, 56]             128
                  ReLU-3           [-1, 64, 56, 56]               0
                Conv2d-4           [-1, 64, 56, 56]          36,864
           BatchNorm2d-5           [-1, 64, 56, 56]             128
                  ReLU-6           [-1, 64, 56, 56]               0
    ================================================================
    Total params: 73,984
    Trainable params: 73,984
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.77
    Forward/backward pass size (MB): 9.19
    Params size (MB): 0.28
    Estimated Total Size (MB): 10.24
    ----------------------------------------------------------------
    None
    
    • Inspect the block structure - with downsampling
      • An extra downsampling layer appears (implemented here with a strided convolution)
    from torchsummary import summary
    from torch.nn import Module, Conv2d, BatchNorm2d, MaxPool2d, ReLU, Sequential, AdaptiveAvgPool2d, Linear
    
    downsample = Sequential( 
        Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False), 
        BatchNorm2d(128)
    )
    
    
    block = ResBlock(64, 128, stride=2, downsample=downsample)
    print(summary(block,input_size=(64, 56, 56), device='cpu'))
    
    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1          [-1, 128, 28, 28]          73,728
           BatchNorm2d-2          [-1, 128, 28, 28]             256
                  ReLU-3          [-1, 128, 28, 28]               0
                Conv2d-4          [-1, 128, 28, 28]         147,456
           BatchNorm2d-5          [-1, 128, 28, 28]             256
                Conv2d-6          [-1, 128, 28, 28]          73,728
           BatchNorm2d-7          [-1, 128, 28, 28]             256
                  ReLU-8          [-1, 128, 28, 28]               0
    ================================================================
    Total params: 295,680
    Trainable params: 295,680
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.77
    Forward/backward pass size (MB): 6.12
    Params size (MB): 1.13
    Estimated Total Size (MB): 8.02
    ----------------------------------------------------------------
    None
    
    • In the example above:
      • Input: x has channels = 64 and spatial size 56 \times 56
      • F(x) has channels = 128 and spatial size 28 \times 28
      • x must therefore be downsampled to channels = 128, 28 \times 28
        • Note: the number of channels has to change as well, which is why a convolution is a natural choice (a quick shape check follows).
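
    • A quick shape check, reusing the block and downsample objects defined above: push a dummy batch through and confirm the shapes.

    import torch
    
    x = torch.randn(1, 64, 56, 56)    # dummy input: 64 channels, 56 x 56
    print(block(x).shape)             # full block output: torch.Size([1, 128, 28, 28])
    print(downsample(x).shape)        # identity branch after downsampling: the same shape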

    Stacking the blocks into a full ResNet

    import torch
    from torch.nn import Module, Conv2d, BatchNorm2d, MaxPool2d, ReLU, Sequential, AdaptiveAvgPool2d, Linear
    
    
    # Residual block
    class YQResBlock(Module):
        def __init__(self, input_channels, output_channels, stride=1, downsample=None):
            super(YQResBlock, self).__init__()
            # First convolution
            self.conv_1 = Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn_1 = BatchNorm2d(output_channels)
            
            # First activation
            self.relu_1 = ReLU(inplace=True)
            
            # Second convolution
            self.conv_2 = Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn_2 = BatchNorm2d(output_channels)
            
            # Optional downsampling of the identity branch
            self.downsample = downsample
            
            # Second activation
            self.relu_2 = ReLU(inplace=True)
            
        def forward(self, x):
            identity = x     # in x + F(x), x may have to be downsampled to match the shape of F(x)
            
            # First conv-bn-relu
            y = self.conv_1(x)
            y = self.bn_1(y)
            y = self.relu_1(y)
            
            # Second conv-bn
            y = self.conv_2(y)
            y = self.bn_2(y)
            
            # Downsample the identity branch if needed
            if self.downsample is not None:
                identity = self.downsample(identity)
            
            y += identity     # residual output: x + F(x)
            y = self.relu_2(y)
            return y
            
            
    class YQResNet(Module):
        """
            cls_num: the number of classes, i.e. the length of the output classification vector.
        """
        def __init__(self, cls_num=1000):
            super(YQResNet, self).__init__()
            # 1. The ResNet head (one convolution + one pooling layer)
            in_channels = 3
            self.header_conv = Conv2d(in_channels, 64, kernel_size=7,  stride=2, padding=3, bias=False)   # 3 input image channels; 64 output channels per the ResNet-18 design
            self.header_bn = BatchNorm2d(64)
            self.header_relu = ReLU(inplace=True)
            
            self.header_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)  # padding is usually half the kernel size
            # 2. The residual stages (4 stages = 4 * 2 * 2 = 16 conv layers)
            # 2.1 Stage 1
            self.res_layer_1_1 = YQResBlock(64, 64, stride=1)
            self.res_layer_1_2 = YQResBlock(64, 64, stride=1)
            
            # 2.2 Stage 2
            downsample_1 = Sequential(    # a local variable, not a module attribute: it is already registered through the block's self.downsample
                Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False), 
                BatchNorm2d(128)
            )
            self.res_layer_2_1 = YQResBlock(64, 128, stride=2, downsample=downsample_1)
            self.res_layer_2_2 = YQResBlock(128, 128, stride=1)
            
            # 2.3 Stage 3
            downsample_2 = Sequential( 
                Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False), 
                BatchNorm2d(256)
            )
            self.res_layer_3_1 = YQResBlock(128, 256, stride=2, downsample=downsample_2)
            self.res_layer_3_2 = YQResBlock(256, 256, stride=1)
            
            # 2.4 Stage 4
            downsample_3 = Sequential( 
                Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False), 
                BatchNorm2d(512)
            )
            self.res_layer_4_1 = YQResBlock(256, 512, stride=2, downsample=downsample_3)
            self.res_layer_4_2 = YQResBlock(512, 512, stride=1)
            
            # 3. The ResNet tail (one pooling layer + fully connected classifier)
            self.footer_pool = AdaptiveAvgPool2d((1, 1))   # adaptive average pooling
            self.classifier = Linear(512 , cls_num)
            
        def forward(self, x):
            y = self.header_conv(x)
            y = self.header_bn(y)
            y = self.header_relu(y)
            y = self.header_pool(y)
            
            y = self.res_layer_1_1(y)
            y = self.res_layer_1_2(y)
            
            y = self.res_layer_2_1(y)
            y = self.res_layer_2_2(y)
            
            y = self.res_layer_3_1(y)
            y = self.res_layer_3_2(y)
            
            y = self.res_layer_4_1(y)
            y = self.res_layer_4_2(y)
            
            y = self.footer_pool(y)
            # Flatten to a vector before the classifier
            y = torch.flatten(y, 1)
            y = self.classifier(y)
            return y
    
    • Inspect the residual network we built, YQResNet
    from torchsummary import summary
    from torch.nn import Module, Conv2d, BatchNorm2d, MaxPool2d, ReLU, Sequential, AdaptiveAvgPool2d, Linear
    
    
    yq_net = YQResNet()
    print(summary(yq_net,input_size=(3, 224, 224), device='cpu'))
    
    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1         [-1, 64, 112, 112]           9,408
           BatchNorm2d-2         [-1, 64, 112, 112]             128
                  ReLU-3         [-1, 64, 112, 112]               0
             MaxPool2d-4           [-1, 64, 56, 56]               0
                Conv2d-5           [-1, 64, 56, 56]          36,864
           BatchNorm2d-6           [-1, 64, 56, 56]             128
                  ReLU-7           [-1, 64, 56, 56]               0
                Conv2d-8           [-1, 64, 56, 56]          36,864
           BatchNorm2d-9           [-1, 64, 56, 56]             128
                 ReLU-10           [-1, 64, 56, 56]               0
           YQResBlock-11           [-1, 64, 56, 56]               0
               Conv2d-12           [-1, 64, 56, 56]          36,864
          BatchNorm2d-13           [-1, 64, 56, 56]             128
                 ReLU-14           [-1, 64, 56, 56]               0
               Conv2d-15           [-1, 64, 56, 56]          36,864
          BatchNorm2d-16           [-1, 64, 56, 56]             128
                 ReLU-17           [-1, 64, 56, 56]               0
           YQResBlock-18           [-1, 64, 56, 56]               0
               Conv2d-19          [-1, 128, 28, 28]          73,728
          BatchNorm2d-20          [-1, 128, 28, 28]             256
                 ReLU-21          [-1, 128, 28, 28]               0
               Conv2d-22          [-1, 128, 28, 28]         147,456
          BatchNorm2d-23          [-1, 128, 28, 28]             256
               Conv2d-24          [-1, 128, 28, 28]          73,728
          BatchNorm2d-25          [-1, 128, 28, 28]             256
                 ReLU-26          [-1, 128, 28, 28]               0
           YQResBlock-27          [-1, 128, 28, 28]               0
               Conv2d-28          [-1, 128, 28, 28]         147,456
          BatchNorm2d-29          [-1, 128, 28, 28]             256
                 ReLU-30          [-1, 128, 28, 28]               0
               Conv2d-31          [-1, 128, 28, 28]         147,456
          BatchNorm2d-32          [-1, 128, 28, 28]             256
                 ReLU-33          [-1, 128, 28, 28]               0
           YQResBlock-34          [-1, 128, 28, 28]               0
               Conv2d-35          [-1, 256, 14, 14]         294,912
          BatchNorm2d-36          [-1, 256, 14, 14]             512
                 ReLU-37          [-1, 256, 14, 14]               0
               Conv2d-38          [-1, 256, 14, 14]         589,824
          BatchNorm2d-39          [-1, 256, 14, 14]             512
               Conv2d-40          [-1, 256, 14, 14]         294,912
          BatchNorm2d-41          [-1, 256, 14, 14]             512
                 ReLU-42          [-1, 256, 14, 14]               0
           YQResBlock-43          [-1, 256, 14, 14]               0
               Conv2d-44          [-1, 256, 14, 14]         589,824
          BatchNorm2d-45          [-1, 256, 14, 14]             512
                 ReLU-46          [-1, 256, 14, 14]               0
               Conv2d-47          [-1, 256, 14, 14]         589,824
          BatchNorm2d-48          [-1, 256, 14, 14]             512
                 ReLU-49          [-1, 256, 14, 14]               0
           YQResBlock-50          [-1, 256, 14, 14]               0
               Conv2d-51            [-1, 512, 7, 7]       1,179,648
          BatchNorm2d-52            [-1, 512, 7, 7]           1,024
                 ReLU-53            [-1, 512, 7, 7]               0
               Conv2d-54            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-55            [-1, 512, 7, 7]           1,024
               Conv2d-56            [-1, 512, 7, 7]       1,179,648
          BatchNorm2d-57            [-1, 512, 7, 7]           1,024
                 ReLU-58            [-1, 512, 7, 7]               0
           YQResBlock-59            [-1, 512, 7, 7]               0
               Conv2d-60            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-61            [-1, 512, 7, 7]           1,024
                 ReLU-62            [-1, 512, 7, 7]               0
               Conv2d-63            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-64            [-1, 512, 7, 7]           1,024
                 ReLU-65            [-1, 512, 7, 7]               0
           YQResBlock-66            [-1, 512, 7, 7]               0
    AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
               Linear-68                 [-1, 1000]         513,000
    ================================================================
    Total params: 13,065,768
    Trainable params: 13,065,768
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.57
    Forward/backward pass size (MB): 62.79
    Params size (MB): 49.84
    Estimated Total Size (MB): 113.21
    ----------------------------------------------------------------
    None
    

    Training the residual network

    • We train on the ImageNet dataset, using 4 of its directories (4 classes).

    Implementing the data loading

    from torchvision.datasets import ImageFolder
    from torchvision.transforms import *
    from torchvision.transforms.functional import *
    from torch.utils.data import random_split
    from torch.utils.data import DataLoader
    
    # Load the images under the given directory and return train/test data loaders split by the given ratio
    def load_data(img_dir, rate=0.8):
        transform = Compose(
            [
                Resize((224, 224)),          #RandomResizedCrop(227),
        #         RandomHorizontalFlip(),
                ToTensor(),
                Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),   # per-channel mean and std (this transform expects a Tensor image, so it must come after ToTensor)
            ]
        )
        ds = ImageFolder(img_dir, transform=transform)
    
        l = len(ds)
        l_train = int(l * rate)
        train, test = random_split(ds, [l_train, l - l_train])
        
        train_loader = DataLoader(dataset=train, shuffle=True, batch_size=200)   # each class has roughly 1,300 images
        test_loader = DataLoader(dataset=test, shuffle=True, batch_size=200)     # evaluated batch by batch
    
        return train_loader, test_loader
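
    • ImageFolder expects one subdirectory per class under the root directory. A hypothetical layout for the 4-class subset used below (the synset folder names are only examples; substitute your own), followed by a quick usage check:

    # imagenet2012/
    # ├── n01440764/   *.JPEG
    # ├── n01443537/   *.JPEG
    # ├── n01484850/   *.JPEG
    # └── n01491361/   *.JPEG
    # Each subdirectory becomes one class label (0..3 here).
    train_loader, test_loader = load_data("./imagenet2012", rate=0.8)
    images, labels = next(iter(train_loader))
    print(images.shape, labels[:8])   # a batch of up to 200 images: torch.Size([200, 3, 224, 224])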
    

    Training

    # from alexnet import AlexNet
    # from dataset import load_data
    import torch
    import torch.utils.data as Data
    import torchvision
    import numpy as np
    import struct
    import cv2
    
    # 1. Load the dataset
    print("1. Loading the dataset")
    train_loader, test_loader = load_data("./imagenet2012", 0.8)
    
    CUDA = torch.cuda.is_available()
    # 2. Build the network
    print("2. Building the network")
    net=YQResNet(4)
    if CUDA:
        net.cuda()
    
    # 3. Train
    print("3. Training")
    optimizer=torch.optim.Adam(net.parameters(),lr=0.001)
    loss_F=torch.nn.CrossEntropyLoss()
    
    epoch = 50
    
    
    for n in range(epoch): # iterate over the training set for `epoch` epochs
        for step, input_data in enumerate(train_loader):
            x_, y_=input_data
            if CUDA:
                # GPU computation: move the batch to the GPU -----------------
                x_ = x_.cuda()
                y_ = y_.cuda()
            pred=net(x_.view(-1, 3, 224, 224))  
            loss=loss_F(pred, y_) # compute the loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        with torch.no_grad():
            all_num = 0.0 
            acc = 0.0
            for t_x, t_y in  test_loader:
                all_num  += len(t_y)
                if CUDA:
                    t_x = t_x.cuda()
                    t_y = t_y.cuda()
                test_pred=net(t_x.view(-1, 3, 224, 224))
                prob=torch.nn.functional.softmax(test_pred, dim=1)
                pred_cls=torch.argmax(prob, dim=1)
                acc += (pred_cls == t_y).float().sum()
            # print(f"Epoch/step: {n:02d}/{step:02d}: \taccuracy: {acc/all_num *100:6.4f}, loss: {loss:6.4f}")
            print(f"Epoch: {n+1:02d}: \taccuracy: {acc/all_num *100:6.4f}, \tloss: {loss:6.4f}")
    # Save the trained model
    torch.save(net.state_dict(), "./resnet.models")  # the state_dict is saved from the GPU model
    
    
    1. Loading the dataset
    2. Building the network
    3. Training
    Epoch: 01:  accuracy: 64.7299,  loss: 0.9878
    Epoch: 02:  accuracy: 70.0306,  loss: 0.7874
    Epoch: 03:  accuracy: 71.8654,  loss: 0.7470
    Epoch: 04:  accuracy: 71.7635,  loss: 0.6195
    Epoch: 05:  accuracy: 75.0255,  loss: 0.6669
    Epoch: 06:  accuracy: 73.5984,  loss: 0.6693
    ......
    Epoch: 45:  accuracy: 83.5882,  loss: 0.0246
    Epoch: 46:  accuracy: 84.3017,  loss: 0.0459
    Epoch: 47:  accuracy: 85.0153,  loss: 0.0206
    Epoch: 48:  accuracy: 84.5056,  loss: 0.0214
    Epoch: 49:  accuracy: 86.5443,  loss: 0.0038
    Epoch: 50:  accuracy: 85.2192,  loss: 0.0014
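
    • After training, a minimal sketch of loading the saved weights back and classifying a single image (the image path here is hypothetical):

    import torch
    from PIL import Image
    from torchvision.transforms import Compose, Resize, ToTensor

    # Rebuild the network and load the trained weights
    net = YQResNet(4)
    net.load_state_dict(torch.load("./resnet.models", map_location="cpu"))
    net.eval()

    transform = Compose([Resize((224, 224)), ToTensor()])
    img = transform(Image.open("test.jpg").convert("RGB")).unsqueeze(0)   # "test.jpg" is a placeholder path

    with torch.no_grad():
        prob = torch.nn.functional.softmax(net(img), dim=1)
    print(prob, prob.argmax(dim=1))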
    

    Appendix:

    • The answers to the two questions below are more or less self-evident; the first can also be verified with a short mathematical derivation (a sketch follows):
      1. Why do residual networks alleviate the vanishing-gradient problem?
      2. Why do residual networks solve the network degradation problem?
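
    • A sketch of the derivation for the first question, following the identity-mapping analysis of He et al. With x_{l+1} = x_l + F(x_l), stacking the blocks from layer l up to a deeper layer L gives

      x_L = x_l + \sum_{i=l}^{L-1} F(x_i)

      \frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \left(1 + \frac{\partial}{\partial x_l} \sum_{i=l}^{L-1} F(x_i)\right)

      The constant term 1 means the gradient at x_L reaches x_l directly, without being multiplied through every intermediate layer, so it does not vanish even when the summed term is small. And if a redundant block simply learns F \approx 0, the block reduces to the identity mapping, which is exactly what the degradation problem requires.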
