美文网首页
torchvision.models

torchvision.models

作者: 小妖怪A | 来源:发表于2021-01-21 10:53 被阅读0次
class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

features和classifier中有一个自适应平均pooling层,将任意输入转换为固定输出,所以不用担心输入图片大小。

相关文章

网友评论

      本文标题:torchvision.models

      本文链接:https://www.haomeiwen.com/subject/kcprzktx.html