获取模型中间层的输出（特征图）
# Build the model's feature extractor as an nn.Sequential whose children are
# registered by name. Excerpt only: this appears to come from a DenseNet-style
# __init__ (module names "denseblock"/"transition"); `self`,
# `num_init_features`, `i`, `block` and `trans` are defined by code not shown
# here -- TODO confirm against the full source.
self.features = nn.Sequential(
OrderedDict(
[
# Stem, registered in this order (so at positions 0-3 of self.features):
# a 7x7 conv (stride 2, padding 3, no bias) from 3 input channels to
# num_init_features channels, batch norm, in-place ReLU, then a 3x3 max pool
# (stride 2, padding 1).
("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
("norm0", nn.BatchNorm2d(num_init_features)),
("relu0", nn.ReLU(inplace=True)),
("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]
)
)
# Register further named children after the stem. The names use i + 1, so they
# are 1-based ("denseblock1", "transition1", ...). Presumably these two calls
# run inside a loop over the dense blocks in the full source -- verify.
self.features.add_module("denseblock%d" % (i + 1), block)
self.features.add_module("transition%d" % (i + 1), trans)
# The remainder of the construction is elided in this excerpt.
...
通过索引（index）获取
# Collect intermediate outputs by position: run a dummy input through each
# child of `self.features` in order and keep the outputs whose index is listed
# in `feature_indices`.
# NOTE(review): `self` is bound by surrounding code not shown here (presumably
# a method of the model) -- confirm.
# Dummy input of shape (1, 3, 320, 320), i.e. (N, C, H, W) as nn.Conv2d
# expects, with values drawn uniformly from [0, 1).
x = torch.rand([1, 3, 320, 320])
# Positions whose outputs are captured (generalized from a single hard-coded
# index). If self.features is the Sequential whose stem is conv0/norm0/relu0/
# pool0, index 2 is the output of "relu0".
feature_indices = {2}
features = []
# Iterate the Sequential directly; this visits children in the same order as
# indexing self.features[0] .. self.features[len - 1].
for i, layer in enumerate(self.features):
    x = layer(x)
    if i in feature_indices:
        features.append(x)
通过子模块名称获取
# Collect intermediate outputs by submodule name: feed a dummy input through
# every registered child of `self.features`, in registration order, and keep
# the output of each child whose name contains "denseblock".
# NOTE(review): `self` is bound by surrounding code not shown here -- confirm.
# `_modules` is iterated rather than named_children(), because the latter skips
# duplicate module instances and could then differ from Sequential's forward.
features = []
x = torch.rand([1, 3, 320, 320])
for name, module in self.features._modules.items():
    x = module(x)
    if 'denseblock' not in name:
        continue
    features.append(x)
# Shape of the final output, after every child has run.
print(x.shape)
网友评论