Image Classification with pytorch-lightning
pytorch-lightning is an API wrapper built on top of PyTorch. It removes a lot of repetitive boilerplate code while keeping PyTorch's flexibility.
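As a rough sketch of the pattern (a minimal, hypothetical example, not part of the project below): you subclass pl.LightningModule, implement forward, training_step and configure_optimizers, and hand the module to a pl.Trainer, which then owns the training loop.
# Minimal hypothetical sketch of the LightningModule/Trainer pattern (0.7.x-style dict returns)
import torch
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# trainer = pl.Trainer(max_epochs=1)
# trainer.fit(LitClassifier(), some_dataloader)   # the Trainer replaces the hand-written loop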
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import torch.nn.functional as F
from collections import OrderedDict
import torch
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
from PIL import Image
import os
from tqdm import tqdm
print(pl.__version__)
print(torch.__version__)
## 0.7.5
## 1.4.0
We use the plant pathology data from Kaggle. The convolutional backbone is a ResNet18, followed by a fully connected (fc) layer that outputs 4 classes.
# Hyperparameters
class config:
    BATCH_SIZE = 32
    BASE_DIR = '/path/plant-pathology-2020-fgvc7/'
    LR = 0.001
# Define the backbone network
class PlantNet(torch.nn.Module):
    def __init__(self, out_features, pretrained):
        super().__init__()
        self.model = torchvision.models.resnet18(pretrained=pretrained)
        for p in self.parameters():
            p.requires_grad = False  # freeze all pretrained weights; only the layer added below will be trained
        in_features = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return F.log_softmax(self.model(x), dim=1)
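A quick, hypothetical check (not in the original post) that the freezing above really leaves only the new fc layer trainable:
# Hypothetical sanity check: after freezing, only the replaced fc layer should require gradients.
net = PlantNet(out_features=4, pretrained=False)   # pretrained=False here just to avoid a download
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
total = sum(p.numel() for p in net.parameters())
print(trainable, total)   # expect 512 * 4 + 4 = 2052 trainable parameters out of roughly 11.2M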
# Define the Dataset
class PlantDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, transforms):
        self.df = pd.read_csv(csv_file)
        self.transforms = transforms

    def __getitem__(self, index):
        image_ids = self.df['image_id'].values
        labels = self.df[self.df.columns[1:]].values
        image = Image.open(config.BASE_DIR + 'images/' + image_ids[index] + '.jpg')
        label = torch.argmax(torch.tensor(labels[index]))  # one-hot columns -> class index
        if self.transforms:
            image = self.transforms(image)
        return image, label

    def __len__(self):
        return len(self.df)
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, transforms):
        self.df = pd.read_csv(csv_file)
        self.transforms = transforms

    def __getitem__(self, index):
        image_ids = self.df['image_id'].values
        image = Image.open(config.BASE_DIR + 'images/' + image_ids[index] + '.jpg')
        if self.transforms:
            image = self.transforms(image)
        return image

    def __len__(self):
        return len(self.df)
# Quick sanity check: load one raw (untransformed) sample
loader = PlantDataset(config.BASE_DIR + 'train.csv', transforms=False)
loader[1][0]
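The label columns of train.csv (healthy, multiple_diseases, rust, scab) are one-hot encoded, and torch.argmax collapses each row into a single class index, which is what __getitem__ returns. A toy illustration with a made-up row:
# Hypothetical one-hot row -> class index, mirroring the label handling in PlantDataset
row = torch.tensor([0, 0, 1, 0])   # a made-up 'rust' sample
print(torch.argmax(row))           # tensor(2), i.e. class index 2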
Define the training procedure with a LightningModule
class PlantLightning(pl.LightningModule):
    def __init__(self, csv_file, pretrained):
        super().__init__()
        self.model = PlantNet(4, pretrained=pretrained)
        self.csv_file = csv_file
        self.best_loss = 10

    # Prepare the data and apply data augmentation
    def prepare_data(self, valid_size=0.2, random_seed=42, shuffle=True):
        transforms = {
            'train': torchvision.transforms.Compose(
                [
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.Resize((256, 256)),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                    )
                ]
            ),
            'valid': torchvision.transforms.Compose(
                [
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.Resize((256, 256)),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                    )
                ]
            ),
        }
        self.train_dataset = PlantDataset(csv_file=self.csv_file,
                                          transforms=transforms['train'])
        self.valid_dataset = PlantDataset(csv_file=self.csv_file,
                                          transforms=transforms['valid'])
        # Split the indices of the same CSV into training and validation subsets
        num_train = len(self.train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(valid_size * num_train))
        if shuffle:
            np.random.seed(random_seed)
            np.random.shuffle(indices)
        train_idx, valid_idx = indices[split:], indices[:split]
        self.train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
        self.valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_idx)
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=config.BATCH_SIZE,
                                           sampler=self.train_sampler,
                                           num_workers=2)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.valid_dataset,
                                           batch_size=config.BATCH_SIZE,
                                           sampler=self.valid_sampler,
                                           num_workers=2)

    def forward(self, x):
        return self.model(x)
    # Define an Adam optimizer (only over trainable parameters) and decay the learning rate at the given epochs
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=config.LR)
        stepLR = torch.optim.lr_scheduler.MultiStepLR(optimizer, [4, 6, 8], gamma=0.1)
        return [optimizer], [stepLR]
    # The training step: the code that would otherwise live inside 'for step, (data, labels) in enumerate(dataloader)'
    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self.forward(images)
        # forward already applies log_softmax, so NLL loss is the matching criterion
        # (cross_entropy here would apply log_softmax a second time)
        loss = F.nll_loss(outputs, labels)
        with torch.no_grad():
            acc = (torch.argmax(outputs, dim=1) == labels).float().mean()
        tqdm_dict = {'train_accuracy': acc}
        # Values are passed around as a dict; 'loss' is required, the rest are optional
        output = OrderedDict({'loss': loss,
                              'num_correct': acc,
                              'log': tqdm_dict,
                              'progress_bar': tqdm_dict})
        return output
    # Operations at the end of each training epoch
    def training_epoch_end(self, outputs):
        """Compute and log training loss and accuracy at the epoch level."""
        train_loss_mean = torch.stack([output['loss']
                                       for output in outputs]).mean()
        # 'num_correct' already holds the per-batch accuracy, so averaging over the number
        # of batches (not batches * BATCH_SIZE) gives the epoch accuracy
        train_acc_mean = torch.stack([output['num_correct']
                                      for output in outputs]).sum().float()
        train_acc_mean /= len(outputs)
        return {'log': {'train_loss': train_loss_mean,
                        'train_acc': train_acc_mean,
                        'step': self.current_epoch}}
    # The validation step
    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self.forward(images)
        loss = F.nll_loss(outputs, labels)  # forward returns log-probabilities
        acc = (torch.argmax(outputs, dim=1) == labels).float().mean()
        return {'val_loss': loss,
                'num_correct': acc}
    # Operations at the end of each validation epoch
    def validation_epoch_end(self, outputs):
        # Mean loss over the epoch
        val_loss_mean = torch.stack([output['val_loss']
                                     for output in outputs]).mean()
        # Mean accuracy over the epoch ('num_correct' holds the per-batch accuracy)
        val_acc_mean = torch.stack([output['num_correct']
                                    for output in outputs]).sum().float()
        val_acc_mean /= len(outputs)
        print(f'Validation Loss: {val_loss_mean}, Validation Accuracy: {val_acc_mean}')
        # Save the best model's weights manually; the ModelCheckpoint callback (sketched below) can do this instead
        # if val_loss_mean < self.best_loss:
        #     self.best_loss = val_loss_mean
        #     torch.save({'best_loss': val_loss_mean, 'model': self.model,
        #                 'model_state_dict': self.model.state_dict()},
        #                'best_model.pt')
        return {'log': {'val_loss': val_loss_mean,
                        'val_acc': val_acc_mean}}
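As a side note on the schedule set up in configure_optimizers: with milestones [4, 6, 8] and gamma=0.1, the learning rate is multiplied by 0.1 when those epochs are reached. A hypothetical standalone trace, outside of any training loop:
# Hypothetical trace of the MultiStepLR schedule used above (dummy parameter, no real training)
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=config.LR)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, [4, 6, 8], gamma=0.1)
for epoch in range(10):
    print(epoch, opt.param_groups[0]['lr'])
    opt.step()    # a no-op here (no gradients); just keeps the conventional step order
    sched.step()
# lr stays at 0.001 for epochs 0-3, then drops to 1e-4, 1e-5 and 1e-6 at epochs 4, 6 and 8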
Once the LightningModule is defined we can train, much like model.fit in Keras: pytorch-lightning trains in a single call, and pl.Trainer lets you specify multi-GPU training, distributed training, callbacks, and so on.
model = PlantLightning(config.BASE_DIR + 'train.csv', pretrained=True)
early_stopping = EarlyStopping('val_loss', patience=3)
trainer = pl.Trainer(gpus=[0], max_nb_epochs=10, early_stop_callback=early_stopping)
trainer.fit(model)
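Instead of the manual torch.save that is commented out in validation_epoch_end, the ModelCheckpoint callback imported at the top can track the best val_loss. A sketch, assuming the pytorch-lightning 0.7.x-style arguments used elsewhere in this post:
# Sketch (assumes the 0.7.x-style API): keep the best weights according to val_loss
checkpoint = ModelCheckpoint(filepath='checkpoints/best', monitor='val_loss',
                             mode='min', save_top_k=1)
# trainer = pl.Trainer(gpus=[0], max_nb_epochs=10,
#                      early_stop_callback=early_stopping,
#                      checkpoint_callback=checkpoint)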
# Start training: the model structure is printed first, then a series of sanity checks run before actual training begins
INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]
INFO:lightning:
| Name | Type | Params
--------------------------------------------------------------------
0 | model | PlantNet | 11 M
1 | model.model | ResNet | 11 M
2 | model.model.conv1 | Conv2d | 9 K
3 | model.model.bn1 | BatchNorm2d | 128
4 | model.model.relu | ReLU | 0
5 | model.model.maxpool | MaxPool2d | 0
6 | model.model.layer1 | Sequential | 147 K
7 | model.model.layer1.0 | BasicBlock | 73 K
8 | model.model.layer1.0.conv1 | Conv2d | 36 K
9 | model.model.layer1.0.bn1 | BatchNorm2d | 128
10 | model.model.layer1.0.relu | ReLU | 0
11 | model.model.layer1.0.conv2 | Conv2d | 36 K
12 | model.model.layer1.0.bn2 | BatchNorm2d | 128
13 | model.model.layer1.1 | BasicBlock | 73 K
14 | model.model.layer1.1.conv1 | Conv2d | 36 K
15 | model.model.layer1.1.bn1 | BatchNorm2d | 128
16 | model.model.layer1.1.relu | ReLU | 0
17 | model.model.layer1.1.conv2 | Conv2d | 36 K
18 | model.model.layer1.1.bn2 | BatchNorm2d | 128
19 | model.model.layer2 | Sequential | 525 K
20 | model.model.layer2.0 | BasicBlock | 230 K
21 | model.model.layer2.0.conv1 | Conv2d | 73 K
22 | model.model.layer2.0.bn1 | BatchNorm2d | 256
23 | model.model.layer2.0.relu | ReLU | 0
24 | model.model.layer2.0.conv2 | Conv2d | 147 K
25 | model.model.layer2.0.bn2 | BatchNorm2d | 256
26 | model.model.layer2.0.downsample | Sequential | 8 K
27 | model.model.layer2.0.downsample.0 | Conv2d | 8 K
28 | model.model.layer2.0.downsample.1 | BatchNorm2d | 256
29 | model.model.layer2.1 | BasicBlock | 295 K
30 | model.model.layer2.1.conv1 | Conv2d | 147 K
31 | model.model.layer2.1.bn1 | BatchNorm2d | 256
32 | model.model.layer2.1.relu | ReLU | 0
33 | model.model.layer2.1.conv2 | Conv2d | 147 K
34 | model.model.layer2.1.bn2 | BatchNorm2d | 256
35 | model.model.layer3 | Sequential | 2 M
36 | model.model.layer3.0 | BasicBlock | 919 K
37 | model.model.layer3.0.conv1 | Conv2d | 294 K
38 | model.model.layer3.0.bn1 | BatchNorm2d | 512
39 | model.model.layer3.0.relu | ReLU | 0
40 | model.model.layer3.0.conv2 | Conv2d | 589 K
41 | model.model.layer3.0.bn2 | BatchNorm2d | 512
42 | model.model.layer3.0.downsample | Sequential | 33 K
43 | model.model.layer3.0.downsample.0 | Conv2d | 32 K
44 | model.model.layer3.0.downsample.1 | BatchNorm2d | 512
45 | model.model.layer3.1 | BasicBlock | 1 M
46 | model.model.layer3.1.conv1 | Conv2d | 589 K
47 | model.model.layer3.1.bn1 | BatchNorm2d | 512
48 | model.model.layer3.1.relu | ReLU | 0
49 | model.model.layer3.1.conv2 | Conv2d | 589 K
50 | model.model.layer3.1.bn2 | BatchNorm2d | 512
51 | model.model.layer4 | Sequential | 8 M
52 | model.model.layer4.0 | BasicBlock | 3 M
53 | model.model.layer4.0.conv1 | Conv2d | 1 M
54 | model.model.layer4.0.bn1 | BatchNorm2d | 1 K
55 | model.model.layer4.0.relu | ReLU | 0
56 | model.model.layer4.0.conv2 | Conv2d | 2 M
57 | model.model.layer4.0.bn2 | BatchNorm2d | 1 K
58 | model.model.layer4.0.downsample | Sequential | 132 K
59 | model.model.layer4.0.downsample.0 | Conv2d | 131 K
60 | model.model.layer4.0.downsample.1 | BatchNorm2d | 1 K
61 | model.model.layer4.1 | BasicBlock | 4 M
62 | model.model.layer4.1.conv1 | Conv2d | 2 M
63 | model.model.layer4.1.bn1 | BatchNorm2d | 1 K
64 | model.model.layer4.1.relu | ReLU | 0
65 | model.model.layer4.1.conv2 | Conv2d | 2 M
66 | model.model.layer4.1.bn2 | BatchNorm2d | 1 K
67 | model.model.avgpool | AdaptiveAvgPool2d | 0
68 | model.model.fc | Linear | 2 K
During training a progress bar is shown together with the loss and accuracy.
Inference on the test set
test_df = pd.read_csv(config.BASE_DIR + 'test.csv')
print('Loading pre-trained model')
model = PlantNet(4, False)
model_ckpt = torch.load('best_model.pt')
print(model.load_state_dict(model_ckpt['model_state_dict']))
print('Testing!')
test_dataset = TestDataset(config.BASE_DIR + 'test.csv',
                           transforms=torchvision.transforms.Compose(
                               [
                                   torchvision.transforms.Resize((256, 256)),
                                   torchvision.transforms.ToTensor(),
                                   # apply the same normalization as at training time
                                   torchvision.transforms.Normalize(
                                       mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225]
                                   )
                               ]
                           )
                           )
# shuffle must be False so predictions line up with the image_id order in test.csv
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
model = model.to('cuda')
model.eval()  # put BatchNorm layers into inference mode
predictions = np.zeros((1, 4))  # placeholder row, dropped below
with torch.no_grad():
    for images in tqdm(test_dataloader):
        images = images.to('cuda')
        preds = torch.exp(model(images))  # the model outputs log-probabilities; exp recovers probabilities
        predictions = np.concatenate((predictions, preds.cpu().detach().numpy()), 0)
output = pd.DataFrame(predictions, columns=['healthy', 'multiple_diseases', 'rust', 'scab'])
output.drop(0, inplace=True)  # drop the placeholder row
output.reset_index(drop=True, inplace=True)
output['image_id'] = test_df.image_id
output = output[['image_id', 'healthy', 'multiple_diseases', 'rust', 'scab']]
# output.to_csv('submission.csv', index=False)
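If a single hard label per image is wanted in addition to the probability columns, the prediction columns can be collapsed with idxmax (a hypothetical follow-up, not part of the competition submission format):
# Hypothetical: pick the highest-scoring class name for each image
hard_labels = output[['healthy', 'multiple_diseases', 'rust', 'scab']].idxmax(axis=1)
print(pd.concat([output['image_id'], hard_labels.rename('label')], axis=1).head())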
Adapted from https://www.kaggle.com/abhiswain/pytorch-lightning-resnet50-simple-baseline