美文网首页
Pytorch Application

Pytorch Application

作者: SnorlaxSE | 来源:发表于2019-10-06 17:35 被阅读0次

Pytorch 模型的网络结构可视化

参考: https://blog.csdn.net/TTdreamloong/article/details/83107110

可能遇到的报错:

  • AttributeError: 'torch._C.Value' object has no attribute 'debugName'
    原因: 当前安装的 torch 版本过旧 (缺少 debugName 接口), 升级即可
    # 列举pip当前可以更新的所有安装包
    pip list --outdated
    # 使用Pip更新Pytorch和torchvision
    pip install --upgrade torch torchvision
    
  • 提示安装"Graphviz"
    原因: Graphviz 的可执行程序需要安装在系统环境中 (如下用 apt-get), 仅在虚拟环境里 pip 安装 graphviz Python 包并不够
    sudo apt-get install graphviz
    

中间特征可视化

参考: PyTorch | 提取神经网络中间层特征进行可视化

    def get_feature(self):
        """Return the activation of layer ``self.selected_layer``.

        The image from ``self.process_image()`` is fed through
        ``self.pretrained_model`` one layer at a time. The model must be
        iterable and yield callables, e.g. an ``nn.Sequential``.

        Returns:
            The output of the layer whose enumeration index equals
            ``self.selected_layer``, or ``None`` if the loop finishes without
            reaching that index.
        """
        # Named `image`, not `input`, so the builtin `input` is not shadowed.
        image = self.process_image()
        print(image.shape)
        x = image
        for index, layer in enumerate(self.pretrained_model):
            x = layer(x)
            if index == self.selected_layer:
                return x
        # selected_layer is beyond the last layer: nothing to return.
        return None

参考: pytorch模型中间层特征的提取

class FeatureExtractor(nn.Module):
    """Collect the outputs of selected direct children of ``submodule``.

    The children registered on ``submodule`` are applied one after another in
    registration order. The submodule's own ``forward`` is NOT used, so any
    logic it contains beyond that plain chain is not reproduced.

    Args:
        submodule: module whose registered children are run sequentially.
        extracted_layers: either a single child name (``str``) or a
            collection of child names whose outputs should be collected.
    """

    def __init__(self, submodule, extracted_layers):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def _is_extracted(self, name):
        """Return True if the child called *name* was requested."""
        # A bare string must match exactly: `name in "10"` is a substring
        # test and would also pick up the children named "1" and "0".
        if isinstance(self.extracted_layers, str):
            return name == self.extracted_layers
        return name in self.extracted_layers

    def forward(self, x):
        """Return the requested intermediate outputs as a list, in layer order."""
        outputs = []
        for name, module in self.submodule._modules.items():
            # Fully connected layers are not visualized for now; they only
            # need their input flattened to (batch, features).
            if "fc" in name:
                # reshape (unlike view) also works on non-contiguous tensors.
                x = x.reshape(x.size(0), -1)
            print(module)
            x = module(x)
            print(name)
            if self._is_extracted(name):
                outputs.append(x)
        return outputs

综上, 完整代码 ⤵️

import cv2
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import torch.nn as nn
from PIL import Image


# Preprocessing applied to every input image:
#   1. resize to 224 x 224;
#   2. ToTensor: PIL image (H, W, C) -> float tensor (C, H, W) in [0, 1]
#      (8-bit pixel values are simply divided by 255);
#   3. per-channel normalization with the ImageNet RGB mean / std that
#      torchvision's pretrained models expect.
simple_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# 中间特征提取
class FeatureExtractor(nn.Module):
    """Return the output of one named direct child of ``submodule``.

    The children registered on ``submodule`` are applied one after another in
    registration order until the requested child has run. The submodule's own
    ``forward`` is NOT used.

    Args:
        submodule: module whose registered children are run sequentially.
        extracted_layer: name of the child whose output is returned, e.g.
            ``"3"`` for index 3 of an ``nn.Sequential`` or ``"layer1"``.
            A collection of names is also accepted; the first match wins.

    Raises:
        ValueError: from ``forward`` if no child matches ``extracted_layer``.
    """

    def __init__(self, submodule, extracted_layer):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layer

    def _is_extracted(self, name):
        """Return True if the child called *name* is the requested one."""
        # Callers pass e.g. str(10); with `in` on a str this becomes a
        # substring test, so children "0" and "1" would match "10" too.
        if isinstance(self.extracted_layers, str):
            return name == self.extracted_layers
        return name in self.extracted_layers

    def forward(self, x):
        """Run the children up to the requested one and return its output."""
        for name, module in self.submodule._modules.items():
            # `==`, not `is`: identity of equal strings is an interpreter
            # detail (and `is` with a literal is a SyntaxWarning).
            if name == "fc":
                # Flatten to (batch, features); reshape also handles
                # non-contiguous tensors, unlike view.
                x = x.reshape(x.size(0), -1)
            x = module(x)
            print("moudle_name", name)
            if self._is_extracted(name):
                return x
        # Previously this fell through and returned None, which only failed
        # later and far away from the real cause.
        raise ValueError(
            "layer {!r} is not a direct child of the submodule".format(self.extracted_layers)
        )


class VisualSingleFeature:
    """Save one channel of a feature-map tensor as a grayscale image.

    Args:
        extract_features: tensor indexed as (batch, channel, height, width),
            e.g. ``torch.Size([1, 128, 112, 112])``. Only sample 0 /
            channel 0 is visualized.
        save_path: destination image path handed to ``cv2.imwrite``.
    """

    def __init__(self, extract_features, save_path):
        self.extract_features = extract_features
        self.save_path = save_path

    def get_single_feature(self):
        """Return channel 0 of sample 0 as a 2-D (height, width) tensor."""
        print(self.extract_features.shape)  # ex. torch.Size([1, 128, 112, 112])

        # Index the sample and the channel directly. The former
        # `[:, 0] + view(H, W)` only worked when the batch size was exactly 1.
        extract_feature = self.extract_features[0, 0, :, :]
        print(extract_feature.shape)  # ex. torch.Size([112, 112])

        return extract_feature

    def save_feature_to_img(self):
        """Squash the feature map into [0, 255] with a sigmoid and save it.

        The parent directory of ``save_path`` is created if it does not exist.

        Raises:
            OSError: if ``cv2.imwrite`` reports that the image was not written.
        """
        import os

        # detach() drops autograd history; cpu() makes CUDA tensors work too.
        extract_feature = self.get_single_feature().detach().cpu().numpy()
        # Sigmoid maps activations into (0, 1). For very negative values,
        # exp overflows to inf and the result is correctly 0, so only the
        # warning is silenced.
        with np.errstate(over='ignore'):
            extract_feature = 1.0 / (1 + np.exp(-1 * extract_feature))
        # Scale to [0, 255] and store as uint8, the pixel type cv2.imwrite
        # expects for common formats such as JPEG/PNG.
        extract_feature = np.round(extract_feature * 255).astype(np.uint8)
        print(extract_feature[0])
        directory = os.path.dirname(self.save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # cv2.imwrite signals failure by returning False instead of raising.
        if not cv2.imwrite(self.save_path, extract_feature):
            raise OSError("could not write feature image to {}".format(self.save_path))


def single_image_sample(img_path='./snorlax.png'):
    """Load an image and return it as a ``(1, 3, 224, 224)`` batch tensor.

    Args:
        img_path: path of the image to load. Defaults to the article's
            sample image.

    Returns:
        The output of ``simple_transform`` for the image, with a leading
        batch dimension of size 1 added.
    """
    # `with` closes the file handle. convert('RGB') returns an independent
    # in-memory copy with exactly three channels (e.g. drops a PNG alpha).
    with Image.open(img_path) as img:
        input_img = img.convert('RGB')
    input_tensor = simple_transform(input_img)
    print(input_tensor.shape)  # torch.Size([3, 224, 224])
    # Models expect a batch axis: (C, H, W) -> (1, C, H, W).
    x = input_tensor.unsqueeze(0)
    print(x.shape)  # torch.Size([1, 3, 224, 224])
    return x
# Figure from the original article: snorlax.png (the sample input image loaded below).
if __name__ == '__main__':
    import os

    import torch

    def _resnet_stem(model, x):
        """Run *x* through ResNet's conv1 -> bn1 -> relu -> maxpool (the input to layer1)."""
        for pre_sequential in ['conv1', 'bn1', 'relu', 'maxpool']:
            x = getattr(model, pre_sequential)(x)
        return x

    # NOTE: `pretrained=True` is the older torchvision API; newer releases
    # deprecate it in favour of `weights=`. Adjust to the installed version.
    #
    # This is pure inference. `.eval()` makes BatchNorm use its running
    # statistics rather than those of this single image, and stops it from
    # overwriting them. `torch.no_grad()` skips building an autograd graph.
    with torch.no_grad():
        # test VGG16: every layer of the convolutional `features` stack (0..30)
        x = single_image_sample()
        # Load the weights once, not once per layer.
        pretrained_module = models.vgg16(pretrained=True).features.eval()
        os.makedirs('./VGG16', exist_ok=True)
        for target_layer in range(len(pretrained_module)):
            myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=str(target_layer))
            target_features = myexactor(x)
            savepath = './VGG16/layer_{}.jpg'.format(target_layer)
            VisualSingleFeature(target_features, savepath).save_feature_to_img()

        # One ResNet-50 instance serves all three experiments below.
        pretrained_model = models.resnet50(pretrained=True).eval()

        # test resnet50 sequential: the network's top-level children
        x = single_image_sample()
        os.makedirs('./Resnet50', exist_ok=True)
        for target_sequential in ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']:
            myexactor = FeatureExtractor(submodule=pretrained_model, extracted_layer=target_sequential)
            target_features = myexactor(x)
            savepath = './Resnet50/{}.jpg'.format(target_sequential)
            VisualSingleFeature(target_features, savepath).save_feature_to_img()

        # test resnet50 layer1: the output of each of its three Bottleneck blocks
        x = _resnet_stem(pretrained_model, single_image_sample())
        pretrained_module = pretrained_model.layer1
        os.makedirs('./Resnet50/layer1', exist_ok=True)
        for target_Bottleneck_index in range(0, 3):
            myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=str(target_Bottleneck_index))
            target_features = myexactor(x)
            savepath = './Resnet50/layer1/{}.jpg'.format(target_Bottleneck_index)
            VisualSingleFeature(target_features, savepath).save_feature_to_img()

        # test resnet50 layer1 Bottleneck0: its children applied in registration order.
        # NOTE(review): running the children one after another presumably
        # differs from Bottleneck.forward (intermediate ReLUs, residual add).
        # These maps are therefore approximations; confirm against torchvision.
        x = _resnet_stem(pretrained_model, single_image_sample())
        pretrained_module = pretrained_model.layer1._modules['0']
        os.makedirs('./Resnet50/layer1/Bottleneck0', exist_ok=True)
        for target_sequential in ['conv1', 'bn1', 'conv2', 'bn2', 'conv3', 'bn3', 'relu']:
            myexactor = FeatureExtractor(submodule=pretrained_module, extracted_layer=target_sequential)
            target_features = myexactor(x)
            savepath = './Resnet50/layer1/Bottleneck0/{}.jpg'.format(target_sequential)
            VisualSingleFeature(target_features, savepath).save_feature_to_img()

优化: 可视化所有通道的特征输出, 而不仅是单个通道 (visualize all feature maps, not just a single one). 可参考: 基于Pytorch的特征图提取
 # 特征输出可视化
    for i in range(feature_channel_number):
        ax = plt.subplot(6, 6, i + 1)
        ax.set_title('Feature {}'.format(i))
        ax.axis('off')
        plt.imshow(target_features.data.numpy()[0,i,:,:],cmap='jet')

    plt.show()

相关文章

网友评论

      本文标题:Pytorch Application

      本文链接:https://www.haomeiwen.com/subject/iugupctx.html