1. PyTorch study notes (1): things to watch out for when fine-tuning a network
https://blog.csdn.net/u014448054/article/details/80623514
2. Global fine-tuning and partial fine-tuning
https://blog.csdn.net/u012759136/article/details/65634477
3. Choosing which layers to train in partial fine-tuning
https://blog.csdn.net/u012436149/article/details/78038098
4. CS231n notes on transfer learning
http://cs231n.github.io/transfer-learning/
5. Pytorch Tutorial for Fine Tuning/Transfer Learning a Resnet for Image Classification
https://github.com/Spandan-Madan/Pytorch_fine_tuning_Tutorial
6. Official PyTorch tutorial and model documentation
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
https://pytorch.org/docs/stable/torchvision/models.html
Network definition files on GitHub
https://github.com/pytorch/vision/tree/master/torchvision/models
Example code:
import torch
from torchvision import models
model = models.resnet18(pretrained=True)
# ####### Global fine-tuning: train every layer, with a different learning rate for fc #########
# ignored_params = list(map(id, model.fc.parameters()))
# base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
#
# optimizer = torch.optim.SGD([{'params': base_params}, {'params': model.fc.parameters(), 'lr': 1e-5}], lr=1e-4, momentum=0.9)
#
####### Partial fine-tuning: train only fc #######################
# in_features = model.fc.in_features
# # Freeze all pretrained parameters first
# for param in model.parameters():
#     param.requires_grad = False
# # Replace fc after freezing so the new layer (15 output classes) stays trainable
# model.fc = torch.nn.Linear(in_features, 15)
# optimizer = torch.optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
#
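###### Partial fine-tuning: train fc and the top block (layer4) #########
fc_params = list(map(id, model.fc.parameters()))
top_params = list(map(id, model.layer4.parameters()))
# ids of every parameter that will be trained (fc plus layer4)
ignored_params = fc_params + top_params
base_params_model = filter(lambda p: id(p) not in ignored_params, model.parameters())
fc_params_model = filter(lambda p: id(p) in fc_params, model.parameters())
top_params_model = filter(lambda p: id(p) in top_params, model.parameters())
# The base layers are not passed to the optimizer; freezing them also skips their gradients
for param in base_params_model:
    param.requires_grad = False
# fc uses the default lr (1e-2); layer4 uses a smaller lr (1e-3)
optimizer = torch.optim.SGD([{'params': fc_params_model}, {'params': top_params_model, 'lr': 1e-3}], lr=1e-2, momentum=0.9)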
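A quick way to check that a partial fine-tuning setup does what is intended is to list the trainable parameters and confirm that a frozen layer is unchanged after one optimizer step. The sketch below is only an illustration: `TinyNet` and `build_partial_finetune_optimizer` are hypothetical names, and the tiny linear model merely stands in for a pretrained network so that no weights need to be downloaded. The learning rates mirror the ones used above.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained backbone (not a torchvision model)."""

    def __init__(self, num_classes=15):
        super().__init__()
        self.layer3 = nn.Linear(8, 8)  # plays the role of the frozen base layers
        self.layer4 = nn.Linear(8, 8)  # plays the role of the top block
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.layer3(x))
        x = torch.relu(self.layer4(x))
        return self.fc(x)


def build_partial_finetune_optimizer(model, fc_lr=1e-2, top_lr=1e-3):
    # Freeze everything first, then re-enable only the layers to be trained.
    for param in model.parameters():
        param.requires_grad = False
    for param in list(model.layer4.parameters()) + list(model.fc.parameters()):
        param.requires_grad = True
    return torch.optim.SGD(
        [
            {"params": model.fc.parameters()},  # falls back to the default lr
            {"params": model.layer4.parameters(), "lr": top_lr},
        ],
        lr=fc_lr,
        momentum=0.9,
    )


model = TinyNet()
optimizer = build_partial_finetune_optimizer(model)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
group_lrs = [group["lr"] for group in optimizer.param_groups]

# One training step on random data; the frozen layer must not move.
frozen_before = model.layer3.weight.clone()
loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
frozen_unchanged = torch.equal(frozen_before, model.layer3.weight)

print(trainable)
print(group_lrs)
print(frozen_unchanged)
```

The same check can be run on the real model by replacing `TinyNet()` with the loaded resnet18, since it also exposes `layer4` and `fc`.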
Error: fc size mismatch
inception_v3 RuntimeError: size mismatch
Cause: the input images were 448x448, but these models expect 224x224 (VGG, ResNet) or 299x299 (Inception).
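The mismatch can be reproduced with plain arithmetic. The sketch below assumes a torchvision release in which ResNet ends with a fixed 7x7 average pool (`nn.AvgPool2d(7, stride=1)`); later releases use adaptive pooling and accept other input sizes. The helper names `conv_out` and `resnet18_fc_input_features` are made up for this illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_fc_input_features(image_size, channels=512, pool_kernel=7):
    """Flattened feature count reaching fc, assuming a fixed 7x7 average pool."""
    size = conv_out(image_size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)        # maxpool
    # layer1 keeps the resolution; layer2, layer3 and layer4 each halve it
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    size = conv_out(size, kernel=pool_kernel, stride=1, padding=0)  # avgpool
    return channels * size * size


for image_size in (224, 448):
    print(image_size, resnet18_fc_input_features(image_size))
```

With a 224x224 input the pooled map is 1x1, so fc receives the 512 features it was built for. With 448x448 the map before pooling is 14x14, the fixed pool leaves 8x8, and fc receives 512 * 64 = 32768 features, hence the size mismatch.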