美文网首页
pytorch-lightning baseline

pytorch-lightning baseline

作者: nnlrl | 来源:发表于2020-05-11 19:07 被阅读0次

使用pytorch-lightning进行图片分类

pytorch-lightning是基于pytorch的API封装,可以节省很多重复的代码,同时又具有pytorch的灵活性

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import torch.nn.functional as F
from collections import OrderedDict

import torch
import torchvision

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
from PIL import Image
import os
from tqdm import tqdm

# Report the library versions this walkthrough was run with (one per line).
print(pl.__version__, torch.__version__, sep='\n')

## 0.7.5
## 1.4.0

使用Kaggle上的植物病害分类数据(plant-pathology-2020-fgvc7),卷积部分使用ResNet18,在其后接一个全连接(fc)层,输出4种类别

# Hyperparameters and paths shared by the whole script.
class config:
    """Global settings, read as class attributes (e.g. ``config.LR``)."""

    # Number of samples per mini-batch in the training/validation loaders.
    BATCH_SIZE: int = 32
    # Dataset root; the script reads train.csv, test.csv and images/*.jpg from it.
    BASE_DIR: str = '/path/plant-pathology-2020-fgvc7/'
    # Learning rate passed to the Adam optimizer.
    LR: float = 0.001
    
# Backbone network: ResNet18 feature extractor with a new classification head.
class PlantNet(torch.nn.Module):
    """ResNet18 with all loaded layers frozen and a fresh trainable linear head.

    Args:
        out_features: number of output classes of the new head.
        pretrained: whether torchvision loads pretrained weights for ResNet18.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1).
    """

    def __init__(self, out_features, pretrained):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        # Freeze every loaded parameter so only the new head is trained.
        for param in backbone.parameters():
            param.requires_grad = False
        # Replace the original classifier; it is created after the freeze loop,
        # so its parameters keep requires_grad=True.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, out_features)
        self.model = backbone

    def forward(self, x):
        logits = self.model(x)
        return F.log_softmax(logits, dim=1)
    
# Dataset over a labelled CSV (used for both training and validation).
class PlantDataset(torch.utils.data.Dataset):
    """Image dataset backed by a labelled CSV.

    The CSV has an ``image_id`` column first, followed by label columns; each
    item is ``(image, label)`` where ``label`` is the argmax over the label
    columns of that row.

    Args:
        csv_file: path of the CSV to read.
        transforms: callable applied to each PIL image, or a falsy value
            (e.g. ``False``/``None``) to return the raw RGB PIL image.
    """

    def __init__(self, csv_file, transforms):
        self.df = pd.read_csv(csv_file)
        self.transforms = transforms
        # Extract the id/label arrays once; rebuilding them inside __getitem__
        # made every single item access O(len(df)).
        self.image_ids = self.df['image_id'].values
        self.labels = self.df[self.df.columns[1:]].values

    def __getitem__(self, index):
        path = config.BASE_DIR + 'images/' + self.image_ids[index] + '.jpg'
        # Image.open is lazy: convert() inside the context manager loads the
        # pixels and lets the file handle be closed. Forcing RGB guarantees the
        # 3 channels the Normalize transform expects.
        with Image.open(path) as img:
            image = img.convert('RGB')
        label = torch.argmax(torch.tensor(self.labels[index]))

        if self.transforms:
            image = self.transforms(image)

        return image, label

    def __len__(self):
        return len(self.df)
    
class TestDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset backed by a CSV with an ``image_id`` column.

    Items are returned in CSV row order, one image per row.

    Args:
        csv_file: path of the CSV to read.
        transforms: callable applied to each PIL image, or a falsy value to
            return the raw RGB PIL image.
    """

    def __init__(self, csv_file, transforms):
        self.df = pd.read_csv(csv_file)
        self.transforms = transforms
        # Extract the ids once instead of rebuilding the array on every access.
        self.image_ids = self.df['image_id'].values

    def __getitem__(self, index):
        path = config.BASE_DIR + 'images/' + self.image_ids[index] + '.jpg'
        # Load inside the context manager so the file handle is closed, and
        # force RGB so downstream tensor transforms see 3 channels.
        with Image.open(path) as img:
            image = img.convert('RGB')

        if self.transforms:
            image = self.transforms(image)
        return image

    def __len__(self):
        return len(self.df)

    
# Sanity check: build the training dataset without transforms and evaluate the
# raw PIL image of sample #1 (a notebook cell displays it).
loader = PlantDataset(csv_file=config.BASE_DIR + 'train.csv',
                      transforms=False)
loader[1][0]

使用LightningModule定义训练过程

class PlantLightning(pl.LightningModule):
    """LightningModule that trains PlantNet (4 classes) on a labelled CSV.

    Args:
        csv_file: labelled CSV (``image_id`` column followed by label columns);
            ``prepare_data`` splits its rows into training and validation subsets.
        pretrained: whether the ResNet18 backbone loads pretrained weights.
        best_model_path: file that receives the weights with the lowest
            validation loss seen so far, saved as a dict with ``'best_loss'``
            and ``'model_state_dict'`` keys (the format the inference code
            reads). Pass ``None`` to disable saving.
    """

    def __init__(self, csv_file, pretrained, best_model_path='best_model.pt'):
        super().__init__()
        self.model = PlantNet(4, pretrained=pretrained)
        self.csv_file = csv_file
        self.best_model_path = best_model_path
        # Lowest validation loss seen so far; +inf so the first run always improves it.
        self.best_loss = float('inf')

    @staticmethod
    def _build_transforms(train):
        """Return the preprocessing pipeline; random flips are added only when ``train``."""
        steps = [torchvision.transforms.RandomHorizontalFlip()] if train else []
        steps += [
            torchvision.transforms.Resize((256, 256)),
            torchvision.transforms.ToTensor(),
            # Standard ImageNet mean/std, matching the pretrained backbone.
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
        ]
        return torchvision.transforms.Compose(steps)

    def prepare_data(self, valid_size=0.2, random_seed=42, shuffle=True):
        """Create the datasets and a reproducible train/validation index split.

        Both datasets read the same CSV and differ only in their transforms:
        augmentation is applied to training data only, so validation is
        deterministic. Disjoint index samplers decide which rows each serves.

        Args:
            valid_size: fraction of rows reserved for validation.
            random_seed: seed for the shuffle of row indices.
            shuffle: whether to shuffle indices before splitting.
        """
        self.train_dataset = PlantDataset(csv_file=self.csv_file,
                                          transforms=self._build_transforms(train=True))
        self.valid_dataset = PlantDataset(csv_file=self.csv_file,
                                          transforms=self._build_transforms(train=False))

        num_samples = len(self.train_dataset)
        indices = list(range(num_samples))
        split = int(np.floor(valid_size * num_samples))

        if shuffle:
            # A private RandomState produces the same permutation as seeding the
            # global NumPy RNG, without resetting global random state.
            np.random.RandomState(random_seed).shuffle(indices)

        train_idx, valid_idx = indices[split:], indices[:split]
        self.train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
        self.valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_idx)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=config.BATCH_SIZE,
                                           sampler=self.train_sampler,
                                           num_workers=2)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.valid_dataset,
                                           batch_size=config.BATCH_SIZE,
                                           sampler=self.valid_sampler,
                                           num_workers=2)

    def forward(self, x):
        """Return PlantNet's log-probabilities for the batch ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam over the trainable (head) parameters, with LR decayed 10x at epochs 4, 6 and 8."""
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
                                     lr=config.LR)
        step_lr = torch.optim.lr_scheduler.MultiStepLR(optimizer, [4, 6, 8], gamma=0.1)
        return [optimizer], [step_lr]

    def training_step(self, batch, batch_idx):
        """Compute the loss and accuracy of one training batch.

        Returns a dict whose ``'loss'`` key Lightning optimizes. ``'num_correct'``
        holds the batch accuracy (a fraction in [0, 1], not a count).
        ``'log'`` and ``'progress_bar'`` carry the same accuracy for display.
        """
        images, labels = batch

        log_probs = self(images)
        # PlantNet already applies log_softmax, so NLL is the matching loss
        # (cross_entropy would apply log_softmax a second time).
        loss = F.nll_loss(log_probs, labels)

        with torch.no_grad():
            acc = (torch.argmax(log_probs, dim=1) == labels).float().mean()

        tqdm_dict = {'train_accuracy': acc}
        return OrderedDict({'loss': loss,
                            'num_correct': acc,
                            'log': tqdm_dict,
                            'progress_bar': tqdm_dict})

    def training_epoch_end(self, outputs):
        """Log the epoch's mean training loss and accuracy.

        ``num_correct`` is already a per-batch mean, so the epoch accuracy is the
        mean of those values. It must not be divided by the batch size again.
        A smaller final batch is weighted like a full one.
        """
        train_loss_mean = torch.stack([output['loss'] for output in outputs]).mean()
        train_acc_mean = torch.stack([output['num_correct'] for output in outputs]).mean()
        return {'log': {'train_loss': train_loss_mean,
                        'train_acc': train_acc_mean,
                        'step': self.current_epoch}}

    def validation_step(self, batch, batch_idx):
        """Compute the loss and accuracy (fraction) of one validation batch."""
        images, labels = batch

        log_probs = self(images)
        loss = F.nll_loss(log_probs, labels)
        acc = (torch.argmax(log_probs, dim=1) == labels).float().mean()

        return {'val_loss': loss,
                'num_correct': acc}

    def validation_epoch_end(self, outputs):
        """Average validation metrics, save the best weights, and log the metrics.

        ``val_loss`` is returned both at the top level and under ``'log'`` so the
        EarlyStopping callback monitoring it can find it.
        """
        val_loss_mean = torch.stack([output['val_loss'] for output in outputs]).mean()
        val_acc_mean = torch.stack([output['num_correct'] for output in outputs]).mean()
        print(f'Validation Loss: {val_loss_mean}, Validation Accuracy: {val_acc_mean}')

        # Keep the weights with the lowest validation loss; the inference code
        # reloads them from ``best_model_path``.
        # NOTE(review): Lightning's pre-training sanity check presumably calls
        # this hook too, so a first save may come from that short run; later
        # epochs overwrite it whenever the loss improves.
        if self.best_model_path is not None and val_loss_mean < self.best_loss:
            self.best_loss = val_loss_mean.item()
            torch.save({'best_loss': self.best_loss,
                        'model_state_dict': self.model.state_dict()},
                       self.best_model_path)

        return {'val_loss': val_loss_mean,
                'log': {'val_loss': val_loss_mean,
                        'val_acc': val_acc_mean}}

定义好LightningModule之后就可以进行训练了:与Keras中的model.fit类似,pytorch-lightning只需一步即可完成训练,同时还可以在pl.Trainer中指定多GPU训练、分布式训练、callback等操作。

# Build the LightningModule and train it; pl.Trainer runs the loop, places the
# model on the requested GPU and drives the callbacks.
model = PlantLightning(config.BASE_DIR + 'train.csv', pretrained=True)
# Stop early when the monitored 'val_loss' stops decreasing (patience of 3 checks).
early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')

# max_epochs replaces the deprecated max_nb_epochs argument.
trainer = pl.Trainer(gpus=[0], max_epochs=10, early_stop_callback=early_stopping)
trainer.fit(model)
# Training starts: Lightning first prints the model summary, runs some sanity
# checks, and then trains the model (sample log output below).
INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]
INFO:lightning:
   | Name                              | Type              | Params
--------------------------------------------------------------------
0  | model                             | PlantNet          | 11 M  
1  | model.model                       | ResNet            | 11 M  
2  | model.model.conv1                 | Conv2d            | 9 K   
3  | model.model.bn1                   | BatchNorm2d       | 128   
4  | model.model.relu                  | ReLU              | 0     
5  | model.model.maxpool               | MaxPool2d         | 0     
6  | model.model.layer1                | Sequential        | 147 K 
7  | model.model.layer1.0              | BasicBlock        | 73 K  
8  | model.model.layer1.0.conv1        | Conv2d            | 36 K  
9  | model.model.layer1.0.bn1          | BatchNorm2d       | 128   
10 | model.model.layer1.0.relu         | ReLU              | 0     
11 | model.model.layer1.0.conv2        | Conv2d            | 36 K  
12 | model.model.layer1.0.bn2          | BatchNorm2d       | 128   
13 | model.model.layer1.1              | BasicBlock        | 73 K  
14 | model.model.layer1.1.conv1        | Conv2d            | 36 K  
15 | model.model.layer1.1.bn1          | BatchNorm2d       | 128   
16 | model.model.layer1.1.relu         | ReLU              | 0     
17 | model.model.layer1.1.conv2        | Conv2d            | 36 K  
18 | model.model.layer1.1.bn2          | BatchNorm2d       | 128   
19 | model.model.layer2                | Sequential        | 525 K 
20 | model.model.layer2.0              | BasicBlock        | 230 K 
21 | model.model.layer2.0.conv1        | Conv2d            | 73 K  
22 | model.model.layer2.0.bn1          | BatchNorm2d       | 256   
23 | model.model.layer2.0.relu         | ReLU              | 0     
24 | model.model.layer2.0.conv2        | Conv2d            | 147 K 
25 | model.model.layer2.0.bn2          | BatchNorm2d       | 256   
26 | model.model.layer2.0.downsample   | Sequential        | 8 K   
27 | model.model.layer2.0.downsample.0 | Conv2d            | 8 K   
28 | model.model.layer2.0.downsample.1 | BatchNorm2d       | 256   
29 | model.model.layer2.1              | BasicBlock        | 295 K 
30 | model.model.layer2.1.conv1        | Conv2d            | 147 K 
31 | model.model.layer2.1.bn1          | BatchNorm2d       | 256   
32 | model.model.layer2.1.relu         | ReLU              | 0     
33 | model.model.layer2.1.conv2        | Conv2d            | 147 K 
34 | model.model.layer2.1.bn2          | BatchNorm2d       | 256   
35 | model.model.layer3                | Sequential        | 2 M   
36 | model.model.layer3.0              | BasicBlock        | 919 K 
37 | model.model.layer3.0.conv1        | Conv2d            | 294 K 
38 | model.model.layer3.0.bn1          | BatchNorm2d       | 512   
39 | model.model.layer3.0.relu         | ReLU              | 0     
40 | model.model.layer3.0.conv2        | Conv2d            | 589 K 
41 | model.model.layer3.0.bn2          | BatchNorm2d       | 512   
42 | model.model.layer3.0.downsample   | Sequential        | 33 K  
43 | model.model.layer3.0.downsample.0 | Conv2d            | 32 K  
44 | model.model.layer3.0.downsample.1 | BatchNorm2d       | 512   
45 | model.model.layer3.1              | BasicBlock        | 1 M   
46 | model.model.layer3.1.conv1        | Conv2d            | 589 K 
47 | model.model.layer3.1.bn1          | BatchNorm2d       | 512   
48 | model.model.layer3.1.relu         | ReLU              | 0     
49 | model.model.layer3.1.conv2        | Conv2d            | 589 K 
50 | model.model.layer3.1.bn2          | BatchNorm2d       | 512   
51 | model.model.layer4                | Sequential        | 8 M   
52 | model.model.layer4.0              | BasicBlock        | 3 M   
53 | model.model.layer4.0.conv1        | Conv2d            | 1 M   
54 | model.model.layer4.0.bn1          | BatchNorm2d       | 1 K   
55 | model.model.layer4.0.relu         | ReLU              | 0     
56 | model.model.layer4.0.conv2        | Conv2d            | 2 M   
57 | model.model.layer4.0.bn2          | BatchNorm2d       | 1 K   
58 | model.model.layer4.0.downsample   | Sequential        | 132 K 
59 | model.model.layer4.0.downsample.0 | Conv2d            | 131 K 
60 | model.model.layer4.0.downsample.1 | BatchNorm2d       | 1 K   
61 | model.model.layer4.1              | BasicBlock        | 4 M   
62 | model.model.layer4.1.conv1        | Conv2d            | 2 M   
63 | model.model.layer4.1.bn1          | BatchNorm2d       | 1 K   
64 | model.model.layer4.1.relu         | ReLU              | 0     
65 | model.model.layer4.1.conv2        | Conv2d            | 2 M   
66 | model.model.layer4.1.bn2          | BatchNorm2d       | 1 K   
67 | model.model.avgpool               | AdaptiveAvgPool2d | 0     
68 | model.model.fc                    | Linear            | 2 K   
在训练过程中显示进度条并打印loss和acc

测试步骤:加载最佳模型权重,并对test.csv中的图片进行预测

# Predict class probabilities for every test image, in test.csv row order.
test_df = pd.read_csv(config.BASE_DIR + 'test.csv')

# Use the GPU when present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Loading pre-trained model')
model = PlantNet(4, False)
# map_location allows a checkpoint written on a GPU to load on a CPU-only host.
model_ckpt = torch.load('best_model.pt', map_location=device)
print(model.load_state_dict(model_ckpt['model_state_dict']))
# eval() makes BatchNorm use its running statistics rather than batch statistics.
model = model.to(device).eval()

print('Testing!')

# Same preprocessing as validation during training, including the ImageNet
# normalization; without it the inputs are off-distribution for the model.
test_dataset = TestDataset(config.BASE_DIR + 'test.csv',
                           transforms=torchvision.transforms.Compose(
                               [
                                   torchvision.transforms.Resize((256, 256)),
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225]
                                   )
                               ]
                           ))

# shuffle=False is required: predictions are matched to test_df.image_id by
# position below, so batches must follow the CSV row order.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

batch_probs = []
with torch.no_grad():
    for images in tqdm(test_dataloader):
        images = images.to(device)
        # PlantNet outputs log-probabilities; softmax maps them back to probabilities.
        preds = torch.nn.functional.softmax(model(images), 1)
        batch_probs.append(preds.cpu().numpy())

predictions = np.concatenate(batch_probs, 0) if batch_probs else np.zeros((0, 4))

output = pd.DataFrame(predictions, columns=['healthy', 'multiple_diseases', 'rust', 'scab'])
output['image_id'] = test_df.image_id
output = output[['image_id', 'healthy', 'multiple_diseases', 'rust', 'scab']]

# output.to_csv('submission.csv', index=False)

参考至https://www.kaggle.com/abhiswain/pytorch-lightning-resnet50-simple-baseline

相关文章

  • pytorch-lightning baseline

    使用pytorch-lightning进行图片分类 pytorch-lightning是基于pytorch的API...

  • Text

    Baseline是基线,在Android中,文字的绘制都是从Baseline处开始的,Baseline往上至字符“...

  • Flutter 基础布局Widgets之Baseline、Asp

    Baseline概述 Baseline即根据child的baseline定位child的小部件,即使得不同的chi...

  • baseline

    bpr: l2 regularization 换成当下pos[batchsize,dim]的l2 norm。opt...

  • Baseline

    什么是基线?这也是一道面试题,如果你只说第一次治疗前的非缺失观测值,虽然没错,但是肯定不能让面试官满意。 大多数A...

  • spring boot flyway 配置说明(摘抄)

    flyway.baseline-description对执行迁移时基准版本的描述. flyway.baseline...

  • 一些css特性

    baseline 应用 display:inline-block; 的元素的 baseline,当其中有 inli...

  • Android - 常用图标大小

    记住: 缩放比例 + Baseline列相关图标的大小 ,就能其算出其他图标大小 MDPI (Baseline)H...

  • Pytorch-lightning入门实例

    本文将通过Colab平台及MINIST数据集指导你了解Pytorch-lightning的核心组成。 注意:任何的...

  • div span垂直居中的问题

    vertical-align的默认值是baseline baseline:将支持valign特性的对象的内容与基线...

网友评论

      本文标题:pytorch-lightning baseline

      本文链接:https://www.haomeiwen.com/subject/oalpnhtx.html