美文网首页
(25)模型的可解释性

(25)模型的可解释性

作者: 顽皮的石头7788121 | 来源:发表于2019-02-26 14:34 被阅读0次

    模型的可解释性问题意在解决模型是通过哪些像素点决定了最终的分类类别。常见的可视化热图生成工具为CAM(Class Activation Mapping)。

    对一个深层的卷积神经网络而言,通过多次卷积和池化以后,它的最后一层卷积层包含了最丰富的空间和语义信息,再往下就是全连接层和softmax层了,其中所包含的信息都是人类难以理解的,很难以可视化的方式展示出来。所以说,要让卷积神经网络的对其分类结果给出一个合理解释,必须要充分利用好最后一个卷积层。

    通过最后一个卷积层的权重与提取到的各个通道特征的点积,并上采样到原始图像大小,再与原始图像结合,就可以生成相应的图像。其代码如下:


#可视化代码

def visual(**kwargs):

    opt.parse(kwargs)

    model = models.densenet121()

    model.classifier = torch.nn.Linear(1024,2);

    checkpoint = torch.load('/home/hdc/yfq/CAG/checkpoints/Densenet1210219_19:09:46.pth')

    model_dict = model.state_dict()

    state_dict = {k: vfor k, vin checkpoint.items()if kin model_dictand "classifier" not in k}

    model.load_state_dict(state_dict,False)

    fc_weight = checkpoint['module.classifier.weight']

    normalize = T.Normalize(    

        mean=[0.485,0.456,0.406],

        std=[0.229,0.224,0.225])

   transforms0 = T.Compose([

        T.RandomResizedCrop(224)

    ])

    transforms1 = T.Compose([

        T.ToTensor(),

        normalize

    ])

# data

    img_path ='/home/hdc/yfq/CAG/data/visual/Yes.478.bmp'

    data = Image.open(img_path)

    data0 = transforms0(data)

    data1 = transforms1(data0)

    data1 = data1.unsqueeze(0)

    model.eval()

    score,feature = model(data1)

    CAMs = returnCAM(feature,fc_weight)

    _,_,height, width = data1.size()

    heatmap = cv2.applyColorMap(cv2.resize(CAMs[1], (width, height)), cv2.COLORMAP_JET)

    result = heatmap *0.3 + np.array(data0) *0.5

     cv2.imwrite('/home/hdc/yfq/CAG/data/visual/Yes.478.CAM0.bmp', result)

def returnCAM(feature_conv, weight_softmax, class_idx =2):

# generate the class activation maps upsample to 256x256

    size_upsample = (224,224)

    bz, nc, h, w = feature_conv.shape

    output_cam = []

    weight_softmax = weight_softmax.unsqueeze(1)

    for idxin range(class_idx):

        cam = weight_softmax[idx].cpu().mm(feature_conv.reshape((nc, h*w)))

        cam = cam.reshape(h, w).detach().numpy()

        cam = cam - np.min(cam)

        cam_img = cam / np.max(cam)

        cam_img = np.uint8(255 * cam_img)

        output_cam.append(cv2.resize(cam_img, size_upsample))

    return output_cam

if __name__=='__main__':

    visual();


 1:

相关文章

网友评论

      本文标题:(25)模型的可解释性

      本文链接:https://www.haomeiwen.com/subject/utpayqtx.html