# Visualizing the feature maps of all ResNet-50 layers in Python


Author: guanalex | Posted 2019-04-23 16:50


    (Reposted from https://blog.csdn.net/u012435142/article/details/84711978)

    Originally posted 2018-12-02 14:18:04 by 未完城

    <article class="baidu_pl" style="box-sizing: inherit; outline: 0px; margin: 0px; padding: 16px 0px 0px; display: block; position: relative; color: rgba(0, 0, 0, 0.75); font-family: -apple-system, "SF UI Text", Arial, "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "WenQuanYi Micro Hei", sans-serif; font-size: 14px; font-style: normal; font-variant-ligatures: common-ligatures; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: 2; text-align: start; text-indent: 0px; text-transform: none; white-space: normal; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">

    Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/u012435142/article/details/84711978

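    Using a pretrained ResNet-50 from PyTorch (torchvision), the script below saves an image of every channel of each layer's feature map while the network runs inference.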
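
The core step is turning one feature tensor of shape [1, C, H, W] into a picture of its channels. As a minimal NumPy sketch (the function name `tile_channels` is hypothetical and not part of the original post), the same per-channel min-max normalization that `draw_features` applies can be used to tile the channels into a single 2-D array, which any image writer could then save. This avoids creating one matplotlib subplot per channel, which is the slow part for the 1024-channel grids.

```python
import numpy as np


def tile_channels(x, cols, rows, eps=1e-6):
    """Tile the first rows*cols channels of x ([1, C, H, W]) into one 2-D image.

    Each channel is min-max normalized to [0, 1] on its own, as in draw_features.
    """
    _, c, h, w = x.shape
    if rows * cols > c:
        raise ValueError("grid has more cells than the tensor has channels")
    grid = np.zeros((rows * h, cols * w), dtype=np.float32)
    for i in range(rows * cols):
        ch = x[0, i].astype(np.float32)
        # Normalize this channel independently; eps avoids division by zero
        ch = (ch - ch.min()) / (ch.max() - ch.min() + eps)
        r, col = divmod(i, cols)  # row-major order, like plt.subplot
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = ch
    return grid


# Example: 64 random channels of 112x112 (the conv1 output size) in an 8x8 grid
feats = np.random.rand(1, 64, 112, 112)
print(tile_channels(feats, cols=8, rows=8).shape)  # (896, 896)
```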


    Code

    import cv2
    import time
    import os
    import matplotlib.pyplot as plt
    import torch
    from torch import nn
    import torchvision.models as models
    import torchvision.transforms as transforms
    import numpy as np

    # Directory where the feature-map images are written
    savepath = 'vis_resnet50/features_elephant'
    # makedirs also creates the parent 'vis_resnet50' directory if it is missing
    os.makedirs(savepath, exist_ok=True)

    def draw_features(width, height, x, savename):
        """Plot the first width*height channels of x ([1, C, H, W]) as a grid and save it."""
        tic = time.time()
        fig = plt.figure(figsize=(16, 16))
        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)
        for i in range(width * height):
            plt.subplot(height, width, i + 1)
            plt.axis('off')
            img = x[0, i, :, :]
            # Min-max normalize each channel to [0, 1]; the epsilon avoids division by zero
            pmin = np.min(img)
            pmax = np.max(img)
            img = (img - pmin) / (pmax - pmin + 0.000001)
            plt.imshow(img, cmap='gray')
            print("{}/{}".format(i, width * height))
        fig.savefig(savename, dpi=100)
        fig.clf()
        plt.close(fig)
        print("time:{}".format(time.time() - tic))

    class ft_net(nn.Module):

        def __init__(self, draw=True):
            super(ft_net, self).__init__()
            # ImageNet-pretrained ResNet-50; newer torchvision versions prefer
            # weights=models.ResNet50_Weights.DEFAULT over the deprecated pretrained=True
            model_ft = models.resnet50(pretrained=True)
            self.model = model_ft
            self.draw = draw  # draw features or not

        def forward(self, x):
            if self.draw:
                x = self.model.conv1(x)
                draw_features(8, 8, x.cpu().numpy(), "{}/f1_conv1.png".format(savepath))

                x = self.model.bn1(x)
                draw_features(8, 8, x.cpu().numpy(), "{}/f2_bn1.png".format(savepath))

                x = self.model.relu(x)
                draw_features(8, 8, x.cpu().numpy(), "{}/f3_relu.png".format(savepath))

                x = self.model.maxpool(x)
                draw_features(8, 8, x.cpu().numpy(), "{}/f4_maxpool.png".format(savepath))

                x = self.model.layer1(x)
                draw_features(16, 16, x.cpu().numpy(), "{}/f5_layer1.png".format(savepath))

                x = self.model.layer2(x)
                draw_features(16, 32, x.cpu().numpy(), "{}/f6_layer2.png".format(savepath))

                x = self.model.layer3(x)
                draw_features(32, 32, x.cpu().numpy(), "{}/f7_layer3.png".format(savepath))

                # layer4 has 2048 channels: save them as two 32x32 grids
                x = self.model.layer4(x)
                draw_features(32, 32, x.cpu().numpy()[:, 0:1024, :, :], "{}/f8_layer4_1.png".format(savepath))
                draw_features(32, 32, x.cpu().numpy()[:, 1024:2048, :, :], "{}/f8_layer4_2.png".format(savepath))

                x = self.model.avgpool(x)  # [1, 2048, 1, 1]
                plt.plot(np.linspace(1, 2048, 2048), x.cpu().numpy()[0, :, 0, 0])
                plt.savefig("{}/f9_avgpool.png".format(savepath))
                plt.clf()
                plt.close()

                x = x.view(x.size(0), -1)  # [1, 2048]
                x = self.model.fc(x)       # [1, 1000] class scores (logits)
                plt.plot(np.linspace(1, 1000, 1000), x.cpu().numpy()[0, :])
                plt.savefig("{}/f10_fc.png".format(savepath))
                plt.clf()
                plt.close()
            else:
                x = self.model.conv1(x)
                x = self.model.bn1(x)
                x = self.model.relu(x)
                x = self.model.maxpool(x)
                x = self.model.layer1(x)
                x = self.model.layer2(x)
                x = self.model.layer3(x)
                x = self.model.layer4(x)
                x = self.model.avgpool(x)
                x = x.view(x.size(0), -1)
                x = self.model.fc(x)

            return x

    # Fall back to the CPU when no GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ft_net().to(device)
    model.eval()

    img = cv2.imread('elephant.png')
    if img is None:
        raise FileNotFoundError('elephant.png not found')
    img = cv2.resize(img, (224, 224))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model expects RGB
    # torchvision's pretrained weights expect the ImageNet mean/std normalization
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    img = transform(img).unsqueeze(0).to(device)  # [1, 3, 224, 224]
    with torch.no_grad():
        start = time.time()
        out = model(img)
        print("total time:{}".format(time.time() - start))
        result = out.cpu().numpy()
        # Class indices sorted by ascending score; the last five are the top-5
        ind = np.argsort(result, axis=1)
        for i in range(5):
            cls = ind[0, 1000 - i - 1]
            print("predict:top {} = cls {} : score {}".format(i + 1, cls, result[0, cls]))
        print("done")
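
The script ranks classes by running `np.argsort` on the raw `fc` output. A NumPy sketch of the same top-k selection, with a softmax added to turn the logits into probabilities (the softmax step and the helper name `top_k` are additions for illustration, not part of the original script):

```python
import numpy as np


def top_k(logits, k=5):
    """Return (class_index, probability) pairs for the k highest logits of a 1-D array."""
    # Softmax, shifted by the max logit for numerical stability
    z = logits - logits.max()
    probs = np.exp(z) / np.exp(z).sum()
    order = np.argsort(logits)[::-1][:k]  # indices of the largest logits first
    return [(int(i), float(probs[i])) for i in order]


scores = np.array([0.1, 2.0, -1.0, 3.5, 0.7])
for rank, (cls, p) in enumerate(top_k(scores, k=3), start=1):
    print("top {} = cls {} : prob {:.3f}".format(rank, cls, p))
```

With the real model, `top_k(result[0])` would apply the same ranking to the 1000 ImageNet scores.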
    
    

    Output shape of each stage for a 224x224 input (the feature-map figures from the original post were not preserved in this repost):

    input image [1,3,224,224]

    conv1 [1,64,112,112]

    bn1_relu [1,64,112,112]

    maxpool [1,64,56,56]

    layer1 [1,256,56,56]

    layer2 [1,512,28,28]

    layer3 [1,1024,14,14]

    layer4 [1,2048,7,7] (saved as two images: channels 0-1023 and 1024-2047)

    avgpool [1,2048,1,1] (plotted as a curve of 2048 values, then flattened to [1,2048] before fc)

    fc [1,1000]
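
The spatial sizes in this list follow from the usual output-size rule for a convolution or pooling layer, out = floor((in + 2*padding - kernel) / stride) + 1. A small sketch that reproduces them, assuming torchvision's ResNet-50 layout: a 7x7 stride-2 conv1 with padding 3, a 3x3 stride-2 maxpool with padding 1, and a stride-2 3x3 convolution (padding 1) in the first block of layer2, layer3 and layer4, while layer1 keeps the resolution. The channel counts are copied from the list, not computed; bn1 and relu do not change the shape.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# (stage, output channels, kernel, stride, padding) of the layer that sets the
# resolution in each stage, assuming torchvision's ResNet-50 definition
stages = [
    ("conv1", 64, 7, 2, 3),
    ("maxpool", 64, 3, 2, 1),
    ("layer1", 256, 3, 1, 1),
    ("layer2", 512, 3, 2, 1),
    ("layer3", 1024, 3, 2, 1),
    ("layer4", 2048, 3, 2, 1),
]

size = 224
shapes = {}
for name, channels, k, s, p in stages:
    size = conv_out(size, k, s, p)
    shapes[name] = [1, channels, size, size]
    print(name, shapes[name])
```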

    </article>


    Article link: https://www.haomeiwen.com/subject/bybpmqtx.html