Study Notes 11: Pretrained Models
Loading the line below automatically downloads the VGG16 model pretrained on ImageNet:
model = torchvision.models.vgg16(pretrained=True)
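Note: on newer torchvision releases (0.13+) the pretrained flag is deprecated in favor of a weights argument; a minimal equivalent, assuming torchvision >= 0.13:

from torchvision.models import vgg16, VGG16_Weights
model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)  # same ImageNet weights, new-style API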
The convolutional base is frozen; only the classifier part is modified.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
import glob
from torchvision import transforms
from torch.utils import data
from PIL import Image
img_dir = r'./data/dataset2/*.jpg'
imgs = glob.glob(img_dir)
# imgs[:3]
species = ['cloudy', 'rain', 'shine', 'sunrise']
species_to_idx = dict((c, i) for i, c in enumerate(species))
# species_to_idx # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
idx_to_species = dict((v, k) for k, v in species_to_idx.items())
# idx_to_species # {0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
labels = []
for img in imgs:
    for i, c in enumerate(species):
        if c in img:
            labels.append(i)
transform = transforms.Compose([      # named "transform" so it does not shadow the transforms module
    transforms.Resize((192, 192)),    # don't make this too small: the originals are 224*224, and if pooling
                                      # shrinks the feature maps below the kernel size, the model errors out
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
class WT_dataset(data.Dataset):
    def __init__(self, imgs_path, labels):
        self.imgs_path = imgs_path
        self.labels = labels

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        label = self.labels[index]
        pil_img = Image.open(img_path)
        pil_img = pil_img.convert("RGB")
        pil_img = transform(pil_img)
        return pil_img, label

    def __len__(self):
        return len(self.imgs_path)
dataset = WT_dataset(imgs, labels)
count = len(dataset)
train_count = int(0.8*count)
test_count = count - train_count
train_dataset, test_dataset = data.random_split(dataset, [train_count, test_count])
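If the split should be reproducible across runs, random_split also accepts a seeded generator (my addition, not in the original notes):

train_dataset, test_dataset = data.random_split(
    dataset, [train_count, test_count],
    generator=torch.Generator().manual_seed(42)  # fixes the split across runs
)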
BATCH_SIZE = 32
train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
)
test_dl = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
)
imgs, labels = next(iter(train_dl))
im = imgs[0].permute(1, 2, 0)  # CHW -> HWC for matplotlib
im = im.numpy()
im = (im + 1) / 2              # undo Normalize(mean=.5, std=.5): map back to [0, 1]
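To actually display the denormalized sample with its class name (a small usage sketch; idx_to_species was built above):

plt.imshow(im)
plt.title(idx_to_species[labels[0].item()])  # class name of the first sample in the batch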
model = torchvision.models.vgg16(pretrained=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

loss_fn = nn.CrossEntropyLoss()

# Freeze the convolutional base; only the classifier is trained
for param in model.features.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a 4-class output.
# (Assigning model.classifier[-1].out_features = 4, as in the original notes,
# only changes an attribute and does not reallocate the layer's weights.)
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 4)

model = model.to(device)  # move after the layer swap so the new layer lands on the device
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.0005)
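A quick sanity check (my own addition, not in the original notes) that only the classifier remains trainable:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print("trainable params: {} / {}".format(trainable, total))  # the frozen conv base contributes nothing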
# The training code used here
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_loss, correct = 0, 0
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
    train_loss /= num_batches
    correct /= size
    return train_loss, correct
def test(dataloader, model, loss_fn):  # pass loss_fn explicitly instead of relying on the global
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    return test_loss, correct
epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_dl, model, loss_fn)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    template = ("epoch:{:2d}, train_loss: {:.5f}, train_acc: {:.1f}% ,"
                "test_loss: {:.5f}, test_acc: {:.1f}%")
    print(template.format(
        epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))
print("Done!")

Output:
Using cpu device
epoch: 0, train_loss: 1.10158, train_acc: 78.5% ,test_loss: 0.24042, test_acc: 92.9%
epoch: 1, train_loss: 0.20001, train_acc: 94.0% ,test_loss: 0.21910, test_acc: 94.7%
epoch: 2, train_loss: 0.18219, train_acc: 96.5% ,test_loss: 0.22200, test_acc: 94.7%
epoch: 3, train_loss: 0.10290, train_acc: 98.1% ,test_loss: 0.24635, test_acc: 95.6%
epoch: 4, train_loss: 0.09968, train_acc: 97.9% ,test_loss: 0.20582, test_acc: 95.6%
epoch: 5, train_loss: 0.34369, train_acc: 98.0% ,test_loss: 0.21849, test_acc: 95.1%
epoch: 6, train_loss: 0.13371, train_acc: 97.9% ,test_loss: 0.16699, test_acc: 96.0%
epoch: 7, train_loss: 0.02434, train_acc: 99.4% ,test_loss: 0.29029, test_acc: 96.0%
epoch: 8, train_loss: 0.02411, train_acc: 99.3% ,test_loss: 0.39036, test_acc: 94.2%
epoch: 9, train_loss: 0.05763, train_acc: 98.8% ,test_loss: 0.45348, test_acc: 94.2%
epoch:10, train_loss: 0.04606, train_acc: 99.0% ,test_loss: 0.30255, test_acc: 95.6%
epoch:11, train_loss: 0.13957, train_acc: 97.4% ,test_loss: 0.80957, test_acc: 92.0%
epoch:12, train_loss: 0.07190, train_acc: 98.7% ,test_loss: 0.92938, test_acc: 93.8%
epoch:13, train_loss: 0.30628, train_acc: 97.4% ,test_loss: 1.01188, test_acc: 92.4%
epoch:14, train_loss: 0.12689, train_acc: 98.8% ,test_loss: 0.77229, test_acc: 94.7%
epoch:15, train_loss: 0.08193, train_acc: 99.1% ,test_loss: 0.96631, test_acc: 94.2%
epoch:16, train_loss: 0.18195, train_acc: 98.6% ,test_loss: 0.66575, test_acc: 93.8%
epoch:17, train_loss: 0.08035, train_acc: 99.1% ,test_loss: 0.97847, test_acc: 95.6%
epoch:18, train_loss: 0.27165, train_acc: 97.9% ,test_loss: 0.53643, test_acc: 95.6%
epoch:19, train_loss: 0.10356, train_acc: 99.0% ,test_loss: 0.50462, test_acc: 96.0%
epoch:20, train_loss: 0.03413, train_acc: 99.7% ,test_loss: 0.51757, test_acc: 96.4%
epoch:21, train_loss: 0.00082, train_acc: 100.0% ,test_loss: 0.68283, test_acc: 96.0%
epoch:22, train_loss: 0.04721, train_acc: 99.4% ,test_loss: 0.93065, test_acc: 95.1%
epoch:23, train_loss: 0.07050, train_acc: 99.4% ,test_loss: 0.84323, test_acc: 96.9%
epoch:24, train_loss: 0.04657, train_acc: 99.8% ,test_loss: 1.29815, test_acc: 94.7%
epoch:25, train_loss: 0.03012, train_acc: 99.9% ,test_loss: 1.36507, test_acc: 94.7%
epoch:26, train_loss: 0.05835, train_acc: 99.7% ,test_loss: 1.02783, test_acc: 93.8%
epoch:27, train_loss: 0.05285, train_acc: 99.6% ,test_loss: 1.44344, test_acc: 93.8%
epoch:28, train_loss: 0.22289, train_acc: 98.9% ,test_loss: 1.26510, test_acc: 95.6%
epoch:29, train_loss: 0.11005, train_acc: 99.4% ,test_loss: 1.69362, test_acc: 95.6%
Done!
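To visualize the curves collected above (a minimal sketch; note in the log how test_loss drifts upward after the first few epochs while train_acc saturates):

epochs_range = range(1, epochs + 1)
plt.plot(epochs_range, train_loss, label='train_loss')
plt.plot(epochs_range, test_loss, label='test_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()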