美文网首页
Pytorch搭建VGG Net

Pytorch搭建VGG Net

作者: Yankee_13 | 来源:发表于2018-12-19 17:42 被阅读0次

    献给莹莹

    1. VGG Net网络结构

    VGG是十分经典的网络了,没什么好说的。网络结构如下

    vgg结构
    注解:

    2.搭建过程

    • 1.加载必要及准备工作
    import torch
    import torch.nn as nn
    
    cfg = {
        'VGG11': [64, 'M', 128, 'M', 256,'M', 512, 'M', 512,'M'],
        'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    }
    # 不同的vgg结构,这样写可以有效节约代码空间。
    
    • 2.构建模型
    class VGG(nn.Module):
    #nn.Module是一个特殊的nn模块,加载nn.Module,这是为了继承父类
        def __init__(self, vgg_name):
            super(VGG, self).__init__()
            #super 加载父类中的__init__()函数
            self.features = self._make_layers(cfg[vgg_name])
            self.classifier=nn.Linear(512,10)
            #该网络输入为Cifar10数据集,因此输出为(512,1,1)
    
        def forward(self, x):
            out = self.features(x)
            out = out.view(out.size(0), -1)
            #这一步将out拉成out.size(0)的一维向量
            out = self.classifier(out)
            
            return out
    
        def _make_layers(self, cfg):
            layers = []
            in_channels = 3
            for x in cfg:
                if x == 'M':
                    layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                else:
                    layers += [nn.Conv2d(in_channels, x, kernel_size=3, 
                               padding=1,bias=False),
                               nn.BatchNorm2d(x),
                               nn.ReLU(inplace=True)]
                    in_channels = x
            return nn.Sequential(*layers)
     '''       
    nn.Sequential(*layers) 表示(只是举例子)
    Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace)
    )
    '''
    
    • 3.检验模型
    def t():
      net = VGG('VGG19')
      x = torch.randn(5,3,32,32)
      y = net(x)
      print(y.size())
    if __name__ == "__main__":
      t()
    #如果输出为(5,10),表示结果正确
    

    相关文章

      网友评论

          本文标题:Pytorch搭建VGG Net

          本文链接:https://www.haomeiwen.com/subject/wxajkqtx.html