定义网络
# Build an untrained ResNet-101 (Bottleneck blocks in a [3, 4, 23, 3] layout,
# 1000 output classes) plus a family of truncated "prefix" networks that expose
# its intermediate feature maps.
resnet101 = torchvision.models.resnet.ResNet(
    torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], 1000
)

# ResNet's forward pass runs conv1 -> bn1 -> relu -> maxpool before layer1.
# The stem therefore has to include bn1 and relu: without them the prefix
# networks below return the right shapes but different activations from the
# ones the full model computes internally.
_stem = [resnet101.conv1, resnet101.bn1, resnet101.relu, resnet101.maxpool]
_stages = [resnet101.layer1, resnet101.layer2, resnet101.layer3, resnet101.layer4]

# Each Sequential wraps the *same* module objects as resnet101, so the prefix
# networks share parameters with the full model.
res_conv1 = torch.nn.Sequential(*_stem[:3])  # conv1 -> bn1 -> relu
res_conv1_maxpool = torch.nn.Sequential(*_stem)  # ... -> maxpool
res_layer1 = torch.nn.Sequential(*_stem, *_stages[:1])
res_layer2 = torch.nn.Sequential(*_stem, *_stages[:2])
res_layer3 = torch.nn.Sequential(*_stem, *_stages[:3])
res_layer4 = torch.nn.Sequential(*_stem, *_stages)
# Output is the pooled map before flatten + fc (still 4-D, spatial size 1x1).
res_avgpool = torch.nn.Sequential(*_stem, *_stages, resnet101.avgpool)
输入图片
# Random stand-in input: batch of 2, 3 channels, 224x224 (NCHW, the layout
# torch Conv2d layers expect), drawn from a standard normal distribution.
img = torch.randn(2, 3, 224, 224)
查看形状
In [51]: img.shape
Out[51]: torch.Size([2, 3, 224, 224])
In [52]: res_conv1(img).shape
Out[52]: torch.Size([2, 64, 112, 112])
In [53]: res_conv1_maxpool(img).shape
Out[53]: torch.Size([2, 64, 56, 56])
In [54]: res_layer1(img).shape
Out[54]: torch.Size([2, 256, 56, 56])
In [55]: res_layer2(img).shape
Out[55]: torch.Size([2, 512, 28, 28])
In [56]: res_layer3(img).shape
Out[56]: torch.Size([2, 1024, 14, 14])
In [57]: res_layer4(img).shape
Out[57]: torch.Size([2, 2048, 7, 7])
In [58]: res_avgpool(img).shape
Out[58]: torch.Size([2, 2048, 1, 1])
In [59]: resnet101(img).shape
Out[59]: torch.Size([2, 1000])
网友评论