美文网首页
PyTorch如何确定全连接的参数

PyTorch如何确定全连接的参数

作者: geekboys | 来源:发表于2020-09-04 03:04 被阅读0次

    如何确定全连接的参数

    虽然目前使用全连接层的网络模型越来越少,但是仍有部分网络需要全连接层,但是如果通过CNN计算图片的输出尺寸可以说有点复杂。现在就使用PyTorch自带的功能来实现这个计算,可以说非常简单。首先,我们先定义如下的网络:

    class LinearDemo(nn.Module):
        def __init__(self):
            super(LinearDemo,self).__init__()
            self.conv=nn.Sequential(
                nn.Conv2d(3,96,kernel_size=11,stride=4),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3,stride=2),
    
                nn.Conv2d(96,256,kernel_size=5,padding=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3,stride=2),
    
                nn.Conv2d(256,384,kernel_size=3,padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(384,384,kernel_size=3,padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(384,256,kernel_size=3,padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3,stride=2)
    
            )
    

    上面代码中的基本组件这里就不多赘述了,下面正常书写全连接层如下:

    self.fc=nn.Sequential(
    #         nn.Linear(???,4096)
    #     )
    

    其中???就是我们需要计算的参数值,如果通过层的关系进行计算则很容易出错。这里推荐使用PyTorch自带的forward方法进行推算。我们写forward方法如下:

    def forward(self,x):
            x=self.conv(x)
            print(x.size())
    

    这里我们可以在main方法中进行调用后,就可以输出该参数。main方法如下:

    net=LinearDemo()
    data_input=torch.randn(1,3,80,280)
    print(data_input.size())
    net(data_input)
    

    这样就将上面的参数输出了。非常的简单

    相关文章

      网友评论

          本文标题:PyTorch如何确定全连接的参数

          本文链接:https://www.haomeiwen.com/subject/zmvasktx.html