美文网首页
如何使用pytorch torchvision.models中的

如何使用pytorch torchvision.models中的

作者: 一位学有余力的同学 | 来源:发表于2020-05-07 20:08 被阅读0次

pytorch中的torchvision.models中包含了多种预训练模型: VGG、Resnet、Googlenet等。然而这些预训练模型的输出分类可能和我们的有差别,所以我们要对预训练模型做出适量修改。

1. 以resnet18为例

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 2)

我们先将resnet18模型加载出来,预训练选择True,新模型的名字命名为model_ft,但是不清楚fc.in_features是什么意思,我们不妨将model_ft输出,得到如下结果:

  ...
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

为了看起来简洁这里只截取了模型的最后两层,想知道详细输出内容可以自行尝试。我们再次将model_ft.fc输出,得到如下结果:

Linear(in_features=512, out_features=1000, bias=True)

可以看到,输出的是model_ft的fc层,那么in_fetures就是fc的输入个数,知道了最后一层的输入个数之后我们就能对预训练网络的最后一层进行替换,替换成我们想要的输出分类数目,例如我们最后的分类结果是一个二分类,则可以将代码写成

model_ft.fc = nn.Linear(num_ftrs, 2)

其中num_ftres = model_ft.fc.in_features,这样我们的模型就建立好了。

2.以VGG11为例

VGG模型包含一个features模块和classifier,以VGG13为例,我们输出VGG的classifier模块,得到如下输出:

(classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )

所以需要对最后一个linear层进行替换,其思想与resnet网络类似:

num_ftrs = model_ft.classifier[6].in_features
model_ft.classifier[6] = nn.Linear(num_ftrs, 2)

最后得到的model_ft模型即为输出为2的VGG13模型。

参考内容:
pytorch迁移学习官方教程

Pytorch预训练模型以及修改

相关文章

网友评论

      本文标题:如何使用pytorch torchvision.models中的

      本文链接:https://www.haomeiwen.com/subject/qslzghtx.html