PyTorch 模型的网络结构可视化
参考: https://blog.csdn.net/TTdreamloong/article/details/83107110
可能遇到的报错:
- AttributeError: 'torch._C.Value' object has no attribute 'debugName'
原因: torch 版本有误(过旧),升级即可:
```bash
# 列举 pip 当前可以更新的所有安装包
pip list --outdated
# 使用 pip 更新 PyTorch 和 torchvision(注意 pip 上的包名是 torch,而不是 pytorch)
pip install --upgrade torch torchvision
```
- 提示安装"Graphviz"
原因: 需要在系统环境中安装 Graphviz 程序本身,仅在虚拟环境中用 pip 安装 Python 接口是不够的,例如: `sudo apt-get install graphviz`
中间特征可视化
参考: PyTorch | 提取神经网络中间层特征进行可视化
def get_feature(self):
    """Run ``self.pretrained_model`` layer by layer and return the output of
    the layer whose index equals ``self.selected_layer``.

    The input comes from ``self.process_image()``; ``self.pretrained_model``
    must be iterable over its layers (it is walked with ``enumerate``).

    Raises:
        IndexError: if ``self.selected_layer`` matches no layer index. Failing
            here is clearer than returning ``None`` and crashing later in the
            visualization code.
    """
    image_tensor = self.process_image()  # named so the builtin ``input`` is not shadowed
    print(image_tensor.shape)
    x = image_tensor
    for index, layer in enumerate(self.pretrained_model):
        x = layer(x)
        if index == self.selected_layer:
            return x
    raise IndexError(
        "selected_layer {!r} does not match any layer index".format(self.selected_layer)
    )
class FeatureExtractor(nn.Module):
    """Collect the outputs of selected direct children of ``submodule``.

    The children are executed one after another in the order they appear in
    ``submodule._modules``. Every child whose name is in ``extracted_layers``
    has its output appended to the returned list.
    """

    def __init__(self, submodule, extracted_layers):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        collected = []
        for child_name, child in self.submodule._modules.items():
            # Any child whose name contains "fc" gets a flattened
            # (batch, -1) input, as a fully-connected head expects.
            needs_flat_input = "fc" in child_name
            if needs_flat_input:
                x = x.view(x.size(0), -1)
            print(child)
            x = child(x)
            print(child_name)
            if child_name in self.extracted_layers:
                collected.append(x)
        return collected
综上, 完整代码 ⤵️
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Preprocessing applied to every input image.
simple_transform = transforms.Compose(
    [
        # Resize to a fixed 224 x 224.
        transforms.Resize((224, 224)),
        # PIL image (H, W, C) with values 0..255 -> float tensor (C, H, W) in [0, 1]
        # (simply divided by 255).
        transforms.ToTensor(),
        # Per-channel normalization with the ImageNet mean/std that
        # torchvision's pretrained models expect.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Intermediate feature extraction
class FeatureExtractor(nn.Module):
    """Chain the direct children of ``submodule`` and return the output of the
    first child whose name is one of the requested layer names.

    Args:
        submodule: module whose direct children are run one after another in
            registration order. ``submodule.forward`` itself is never called,
            so this is only faithful for purely sequential containers.
        extracted_layer: a single child name (``str``) or an iterable of names.

    Raises:
        ValueError: if no child of ``submodule`` has a requested name.
    """

    def __init__(self, submodule, extracted_layer):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layer

    def _target_names(self):
        """Return the requested layer names as a set of exact names."""
        # A bare string must be matched as one whole name. ``name in "10"``
        # would be a substring test, so child "0" or "1" would wrongly match
        # target "10".
        if isinstance(self.extracted_layers, str):
            return {self.extracted_layers}
        return set(self.extracted_layers)

    def forward(self, x):
        targets = self._target_names()
        for name, module in self.submodule.named_children():
            if name == "fc":  # ``==``, not ``is``: compare string values, not identity
                # reshape (unlike view) also works on non-contiguous tensors
                x = x.reshape(x.size(0), -1)
            x = module(x)
            print("moudle_name", name)
            if name in targets:
                return x
        raise ValueError(
            "none of {} is a child of {}".format(
                sorted(targets), type(self.submodule).__name__
            )
        )
class VisualSingleFeature():
    """Save one channel of a feature-map tensor as a grayscale image.

    Args:
        extract_features: a (batch, channel, H, W) feature tensor, e.g.
            torch.Size([1, 128, 112, 112]); only the first sample is used.
        save_path: output image path; missing parent folders are created.
    """

    def __init__(self, extract_features, save_path):
        self.extract_features = extract_features
        self.save_path = save_path

    def get_single_feature(self, channel=0):
        """Return the (H, W) map of ``channel`` for the first sample.

        E.g. torch.Size([1, 128, 112, 112]) -> torch.Size([112, 112]).
        Indexing works for any batch size; ``view(H, W)`` only worked for 1.
        """
        return self.extract_features[0, channel]

    @staticmethod
    def _to_uint8(feature):
        """Squash a float array into [0, 255] with a sigmoid and round to uint8."""
        # exp(-x) overflows to inf for very negative x. 1 / (1 + inf) is still
        # the correct limit 0, so only the overflow warning is silenced.
        with np.errstate(over='ignore'):
            squashed = 1.0 / (1 + np.exp(-1 * feature))
        return np.round(squashed * 255).astype(np.uint8)

    def save_feature_to_img(self, channel=0):
        """Write channel ``channel`` of the first sample to ``self.save_path``.

        Raises:
            OSError: if OpenCV could not write the file.
        """
        # detach/cpu so grad-tracking or CUDA tensors convert to numpy
        feature = self.get_single_feature(channel).detach().cpu().numpy()
        image = self._to_uint8(feature)
        directory = os.path.dirname(self.save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # cv2.imwrite reports failure by returning False rather than raising
        if not cv2.imwrite(self.save_path, image):
            raise OSError("cv2.imwrite failed to write {}".format(self.save_path))
def single_image_sample(img_path='./snorlax.png'):
    """Load an RGB image and return it as a normalized batch of one.

    Args:
        img_path: path of the image to load (defaults to the demo image).

    Returns:
        Tensor of shape (1, 3, 224, 224) produced by ``simple_transform``.
    """
    # ``with`` closes the underlying file handle deterministically
    with Image.open(img_path) as img:
        input_tensor = simple_transform(img.convert('RGB'))
    print(input_tensor.shape)  # torch.Size([3, 224, 224])
    x = input_tensor.unsqueeze(0)  # add the batch dimension
    print(x.shape)  # torch.Size([1, 3, 224, 224])
    return x
![](https://img.haomeiwen.com/i5267500/24e3ff2273ee6a39.png)
# test VGG16: save channel 0 of every layer in vgg16.features
x = single_image_sample()
# Build the model once, outside the loop, and use eval mode for inference.
pretrained_module = models.vgg16(pretrained=True).features.eval()
os.makedirs('./VGG16', exist_ok=True)  # cv2.imwrite does not create folders
with torch.no_grad():  # inference only: no autograd graph is needed
    for target_layer in range(len(pretrained_module)):  # 31 layers for VGG16
        # Pass the name inside a list so it is matched exactly. As a bare
        # string, "0" or "1" would match target "10" by substring.
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=[str(target_layer)])
        target_features = myexactor(x)
        savepath = './VGG16/layer_{}.jpg'.format(target_layer)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
![](https://img.haomeiwen.com/i5267500/a13b38acde1dcddc.png)
# test resnet50 sequential: outputs of the top-level ResNet-50 stages
x = single_image_sample()
# Load the weights once and switch to eval mode. That makes BatchNorm use its
# stored running statistics; in train mode it would normalize with the
# statistics of this single image and update its running buffers.
pretrained_module = models.resnet50(pretrained=True).eval()
os.makedirs('./Resnet50', exist_ok=True)  # cv2.imwrite does not create folders
with torch.no_grad():
    for target_sequential in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']:
        # a list makes the extractor match the child name exactly
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=[target_sequential])
        target_features = myexactor(x)
        savepath = './Resnet50/{}.jpg'.format(target_sequential)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
![](https://img.haomeiwen.com/i5267500/b4ef081a4eff308f.png)
# test resnet50 layer1: output of each Bottleneck block inside layer1
x = single_image_sample()
pretrained_model = models.resnet50(pretrained=True).eval()  # eval: BatchNorm uses running statistics
pretrained_module = pretrained_model.layer1
os.makedirs('./Resnet50/layer1', exist_ok=True)  # cv2.imwrite does not create folders
with torch.no_grad():
    # Run the stem (conv1 -> bn1 -> relu -> maxpool) to obtain layer1's input.
    for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
        x = getattr(pretrained_model, pre_sequential)(x)
    for target_Bottleneck_index in range(len(pretrained_module)):
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=[str(target_Bottleneck_index)])
        target_features = myexactor(x)
        savepath = './Resnet50/layer1/{}.jpg'.format(target_Bottleneck_index)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
![](https://img.haomeiwen.com/i5267500/908d4aaceb3cb3bf.png)
# test resnet50 layer1 Bottleneck0: children of the first Bottleneck in layer1
x = single_image_sample()
pretrained_model = models.resnet50(pretrained=True).eval()  # eval: BatchNorm uses running statistics
pretrained_module = pretrained_model.layer1[0]
os.makedirs('./Resnet50/layer1/Bottleneck0', exist_ok=True)  # cv2.imwrite does not create folders
with torch.no_grad():
    # Run the stem (conv1 -> bn1 -> relu -> maxpool) to obtain layer1's input.
    for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
        x = getattr(pretrained_model, pre_sequential)(x)
    # NOTE(review): FeatureExtractor chains the Bottleneck's children one after
    # another and never calls Bottleneck.forward. Any activation or residual
    # logic inside that forward is therefore skipped, so treat these maps as
    # approximations of the real intermediate activations.
    for target_sequential in ['conv1', 'bn1', 'conv2', 'bn2', 'conv3', 'bn3', 'relu']:
        myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=[target_sequential])
        target_features = myexactor(x)
        savepath = './Resnet50/layer1/Bottleneck0/{}.jpg'.format(target_sequential)
        VisualSingleFeature(target_features, savepath).save_feature_to_img()
![](https://img.haomeiwen.com/i5267500/7ebb3c881420cf58.png)
优化:
Visualize all feature outputs, not only a single feature.
可参考:基于Pytorch的特征图提取
# Feature-map visualization: plot up to 36 channels of the first sample in a
# 6-column grid. Assumes ``target_features`` is a (batch, channel, H, W) tensor
# from the extraction above and that matplotlib.pyplot is imported as ``plt``.
feature_maps = target_features.detach().cpu().numpy()[0]  # convert once, outside the loop
feature_channel_number = min(feature_maps.shape[0], 36)  # a 6 x 6 grid holds at most 36 plots
n_cols = 6
n_rows = (feature_channel_number + n_cols - 1) // n_cols  # ceil division
for i in range(feature_channel_number):
    ax = plt.subplot(n_rows, n_cols, i + 1)
    ax.set_title('Feature {}'.format(i))
    ax.axis('off')
    ax.imshow(feature_maps[i], cmap='jet')
plt.show()
- 使用 hook 提取中间层特征(未测试),可参考:『PyTorch』第十六弹_hook技术、PyTorch 可视化 CNN 中间层的输出
网友评论