美文网首页
pytorch 定义一个网络

pytorch 定义一个网络

作者: Zeke_Wang | 来源:发表于2018-12-12 19:03 被阅读0次

    声明一个关于网络的类

    import torch.nn as nn
    class NetName(nn.Module):
        def __init__(self):
            super(NetName, self).__init__()
    
            nn.module1 = ...
            nn.module2 = ...
            nn.module3 = ...
        
        def forward(self,x):
            x = self.module1(x)
            x = self.module1(x)
            x = self.module2(x)
            x = self.module3(x)
            return x
    

    其中在构造函数__init__中构造这个NN中需要使用的各种模块(module),比如:参数完全相同的maxpooling声明为一个模块,或者例如在CV任务中,把feature_extraction的网络和classification的网络分别声明。
    forward函数用于声明各个模块间的关系。即,连接整个网络。

    net = NetName().to(device) # 创建网络,并放入指定的device
    

    网络创建后,可以通过以下方式遍历模块信息:

    for name, module in net._modules.items():
        print(name) # name就是__init__中的各个模块名
        print(module) # module就是各个模块内具体的层
    

    示例:AlexNet

    注释中的tensor大小变化是基于cifar10的图片----(channel=3, height=32, width=32)

    import torch.nn as nn
    
    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), # (3,32,32) -> (64,8,8)
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),                 # (64,8,8)  -> (64,4,4)
                nn.Conv2d(64, 192, kernel_size=5, padding=2),          # (64,4,4)  -> (192,4,4)
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),                 # (192,4,4) -> (192,2,2)
                nn.Conv2d(192, 384, kernel_size=3, padding=1),         # (192,2,2) -> (384,2,2)
                nn.ReLU(inplace=True),
                nn.Conv2d(384, 256, kernel_size=3, padding=1),         # (384,2,2) -> (256,2,2)
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),         # (256,2,2) -> (256,2,2)
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),                 # (256,2,2) -> (256,1,1)
            )
            
            self.classifier = nn.Linear(256, 10)                       # (batch_size,256) -> (batch_size,10)
            
    
        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1) # flatten to (batch_size, 256*1*1)
            x = self.classifier(x)
            return x
    

    相关文章

      网友评论

          本文标题:pytorch 定义一个网络

          本文链接:https://www.haomeiwen.com/subject/wmvphqtx.html