美文网首页
PyTorch 获取模型中间层方法

PyTorch 获取模型中间层方法

作者: 翻开日记 | 来源:发表于2022-07-21 16:13 被阅读0次

    获取模型中间层

    self.features = nn.Sequential(
                OrderedDict(
                    [
                        ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                        ("norm0", nn.BatchNorm2d(num_init_features)),
                        ("relu0", nn.ReLU(inplace=True)),
                        ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                    ]
                )
            )
    self.features.add_module("denseblock%d" % (i + 1), block)
    self.features.add_module("transition%d" % (i + 1), trans)
    ...
    

    通过index获取

            x = torch.rand([1, 3, 320, 320])
            features = []
            for i in range(len(self.features)):
                x = self.features[i](x)
                if i == 2:
                    features.append(x)
    

    通过特征名获取

            x = torch.rand([1, 3, 320, 320])
            features = []
            for name, module in self.features._modules.items():
                x = module(x)
                if 'denseblock' in name:
                    features.append(x)
                    print(x.shape)
    

    相关文章

      网友评论

          本文标题:PyTorch 获取模型中间层方法

          本文链接:https://www.haomeiwen.com/subject/rogklrtx.html