class Resnet50Extractor(nn.Module):
    """Truncate a ResNet-50 backbone and return an intermediate feature map.

    The truncated network is built once, in ``__init__``, and stored as
    ``self.submodule``. It is no longer rebuilt and re-assigned inside
    ``forward`` on every call. That mutated the module during the forward
    pass, and until the first call ``parameters()`` / ``state_dict()`` /
    ``.to()`` still described the full model, including the final fc layer.

    The backbone's top-level children are selected by position. The
    positions assume torchvision's ResNet child order (as in the
    ``models.resnet50`` usage example): conv1, bn1, relu, maxpool,
    layer1..layer4, avgpool, fc. NOTE(review): confirm this before passing
    any other backbone.

    Args:
        submodule: The backbone to truncate. Its child modules are shared
            with the extractor, not copied.
        extracted_layer: Where to cut the network:
            ``'maxpool'``       -> children[:4] (stem up to the max-pool);
            ``'inner-layer-3'`` -> children[:6] (through layer2) followed by
                                   the first three inner blocks of child 6
                                   (layer3);
            ``'layer-3'``       -> children[:7] (through layer3);
            anything else       -> children[:9] (through avgpool, i.e. all
                                   but fc). The result is not flattened.
            The value is read only at construction time.
    """

    def __init__(self, submodule, extracted_layer):
        super().__init__()
        self.extracted_layer = extracted_layer
        self.submodule = self._truncate(submodule, extracted_layer)

    @staticmethod
    def _truncate(model, extracted_layer):
        """Return an ``nn.Sequential`` of ``model``'s children up to the cut point."""
        children = list(model.children())
        if extracted_layer == 'maxpool':
            modules = children[:4]
        elif extracted_layer == 'inner-layer-3':
            # Child 6 is layer3; keep only its first three inner blocks.
            layer3_head = list(children[6].children())[:3]
            modules = children[:6] + [nn.Sequential(*layer3_head)]
        elif extracted_layer == 'layer-3':
            modules = children[:7]
        else:
            # Default cut: everything up to and including avgpool (drops fc).
            modules = children[:9]
        return nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the truncated backbone and return the features."""
        return self.submodule(x)
Then call it as follows:
import torch

# NOTE(review): `extracted_layer` (e.g. 'layer-3') and `input_tensor` are not
# defined in this snippet; the caller must provide them.
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of the
# `weights=` argument; it is kept here for compatibility with older versions.
model_ft = models.resnet50(pretrained=True)
extractor = Resnet50Extractor(model_ft, extracted_layer)

# Feature extraction is inference. eval() makes BatchNorm use its stored
# running statistics instead of per-batch statistics, which would also
# overwrite the pretrained running averages. no_grad() skips building
# autograd graphs.
extractor.eval()
with torch.no_grad():
    features = extractor(input_tensor)
网友评论