美文网首页
PyTorch 获取模型中间层方法

PyTorch 获取模型中间层方法

作者: 翻开日记 | 来源:发表于2022-07-21 16:13 被阅读0次

获取模型中间层

self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", nn.BatchNorm2d(num_init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )
self.features.add_module("denseblock%d" % (i + 1), block)
self.features.add_module("transition%d" % (i + 1), trans)
...

通过index获取

        x = torch.rand([1, 3, 320, 320])
        features = []
        for i in range(len(self.features)):
            x = self.features[i](x)
            if i == 2:
                features.append(x)

通过特征名获取

        x = torch.rand([1, 3, 320, 320])
        features = []
        for name, module in self.features._modules.items():
            x = module(x)
            if 'denseblock' in name:
                features.append(x)
                print(x.shape)

相关文章

网友评论

      本文标题:PyTorch 获取模型中间层方法

      本文链接:https://www.haomeiwen.com/subject/rogklrtx.html