美文网首页一起折腾Pytorch
PyTorch | 教你用小妙招提取神经网络某一层特征

PyTorch | 教你用小妙招提取神经网络某一层特征

作者: 与阳光共进早餐 | 来源:发表于2018-08-16 16:05 被阅读0次

    一 写在前面

    未经允许,不得转载,谢谢。

    我们常常需要提取神经网络某一层得到的结果作为特征进行处理。

    example

    直观来讲,我们想提取最后一层fc前面层的输出作为特征,那么怎么样才能获取到呢?

    在我看到的大部分给出的教程中都使用了重复写模型、或者用hook钩子函数的做法,这样的做法所需要的代码量比较大,又不是非常简洁。

    二 解决方法

    我突然想到我们是不是可以获取模型中的变量,那么我们只需要将自己想要的变量保存下来即可。

    比如,在这个例子中我想获得fc前面的特征,那么我就在前面加一句self.feature=x,然后再需要的地方调用model.feature即可得到。

    这样就可以避免以上的麻烦了,简单省事。

    如下图所示:

    三 实验验证

    我用先取得feature值,再将它单独传入到fc层得到的结果与之间将数据传入模型得到的结过进行比较的方法来验证实验。

    • 验证代码:
        # load model and data
        model = P3D199()
        model = model.cuda()
        model.eval()
        data=torch.autograd.Variable(torch.rand(16,3,16,160,160)).cuda() 
        # verify
        out=model(data)
        feature=model.feature
        out2=model.fc(feature)
        print(out==out2)   
    

    这样提取得到的特征可以用于中间层可视化,当然也可以用于对特征层的输出做一些额外的处理~


    20190214更新帖~~~~~~~~

    1 问题描述

    有的时候我们的模型不是自己写的,而是直接调用pytorch封装好的torchvision.models里面的模型,那代码文件就没有修改的权限 。

    2 解决方法

    大概可以有两种解决方法:

    1. 模型不是很复杂的时候,直接拷贝源码整理成自己的model文件使用;
    2. 网上比较多的使用的是hook之类的教程,我个人觉得有点麻烦(懒癌==);
    3. 重写forward方法(不需要继承的情况下):
    • 以resnet18为例:
    model = models.resnet18(pretrained=False,num_classes=CIFAR10_num_classes)
    def my_forward(model, x):
        mo = nn.Sequential(*list(model.children())[:-1])
        feature = mo(x)
        feature = feature.view(x.size(0), -1)
        output= model.fc(feature)
        return feature, output
    
    • 实验验证:
        print (model)
        input = Variable(torch.rand(8,3,224,224)).cuda()
        model.cuda()
    
        # get output directly
        print (input.shape)     # [8,3,224,224]
        output = model(input)
        print (output.shape)    # [8,num_classes]
    
        # get feature and output in a new way
        myfeature,myoutput=my_forward(model,input)
        print(myfeature.shape,myoutput.shape)  # [8,512]  [8,num_classes]
        print(myoutput==output)   # equal
    

    相关文章

      网友评论

        本文标题:PyTorch | 教你用小妙招提取神经网络某一层特征

        本文链接:https://www.haomeiwen.com/subject/yglobftx.html