美文网首页
pytorch-lightning baseline

pytorch-lightning baseline

作者: nnlrl | 来源:发表于2020-05-11 19:07 被阅读0次

    使用pytorch-lightning进行图片分类

    pytorch-lightning是基于pytorch的API封装,可以节省很多重复的代码,同时又具有pytorch的灵活性

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    import torch.nn.functional as F
    from collections import OrderedDict
    
    import torch
    import torchvision
    
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import glob
    from PIL import Image
    import os
    from tqdm import tqdm
    
    # Record the library versions this baseline was written against
    # (the '##' lines below are the captured output).
    print(pl.__version__)
    print(torch.__version__)
    
    ## 0.7.5
    ## 1.4.0
    

    使用kaggle上的植物分类数据,卷积层使用Resnet18,在后面连一个fc层,输出4种类别

    # Hyper-parameters
    class config:
        """Central namespace for the run's tunable settings."""
        BASE_DIR = '/path/plant-pathology-2020-fgvc7/'  # dataset root directory
        BATCH_SIZE = 32                                 # samples per mini-batch
        LR = 0.001                                      # Adam learning rate
        
    # Backbone network definition
    class PlantNet(torch.nn.Module):
        """ResNet-18 backbone with a fresh classification head.

        The pretrained backbone is frozen; only the newly created fully
        connected layer (out_features classes) receives gradients.
        """

        def __init__(self, out_features, pretrained):
            super().__init__()
            self.model = torchvision.models.resnet18(pretrained=pretrained)
            # Freeze every backbone weight first; the Linear head created
            # afterwards keeps requires_grad=True, so only it will train.
            for param in self.parameters():
                param.requires_grad = False

            self.model.fc = torch.nn.Linear(self.model.fc.in_features,
                                            out_features)

        def forward(self, x):
            logits = self.model(x)
            return F.log_softmax(logits, dim=1)
        
    # Labelled dataset definition
    class PlantDataset(torch.utils.data.Dataset):
        """Labelled dataset over the competition train CSV.

        Each item is (image, label) where the label is the argmax over the
        one-hot columns after 'image_id'.  `transforms`, if truthy, is
        applied to the PIL image.
        """

        def __init__(self, csv_file, transforms):
            self.df = pd.read_csv(csv_file)
            self.transforms = transforms
            # Hoisted out of __getitem__: the original rebuilt these numpy
            # arrays from the DataFrame on every single item access.
            self.image_ids = self.df['image_id'].values
            self.labels = self.df[self.df.columns[1:]].values

        def __getitem__(self, index):
            image = Image.open(config.BASE_DIR + 'images/' + self.image_ids[index] + '.jpg')
            # One-hot row -> class index.
            label = torch.argmax(torch.tensor(self.labels[index]))

            if self.transforms:
                image = self.transforms(image)

            return image, label

        def __len__(self):
            return len(self.df)
        
    class TestDataset(torch.utils.data.Dataset):
        """Unlabelled dataset for inference: yields transformed images only."""

        def __init__(self, csv_file, transforms):
            self.df = pd.read_csv(csv_file)
            self.transforms = transforms
            # Hoisted out of __getitem__: the id column never changes, so
            # extract it once instead of on every item access (the original
            # re-read the whole column per call).
            self.image_ids = self.df['image_id'].values

        def __getitem__(self, index):
            image = Image.open(config.BASE_DIR + 'images/' + self.image_ids[index] + '.jpg')

            if self.transforms:
                image = self.transforms(image)
            return image

        def __len__(self):
            return len(self.df)
    
        
    # Smoke test: build the dataset and pull one raw (untransformed) image.
    # transforms=False simply disables the transform branch in __getitem__.
    loader = PlantDataset(config.BASE_DIR + 'train.csv', transforms=False)
    loader[1][0]
    

    使用LightningModule定义训练过程

    class PlantLightning(pl.LightningModule):
        """LightningModule tying PlantNet to the plant-pathology data.

        Owns the train/valid split (disjoint SubsetRandomSamplers over the
        same CSV), the Adam optimizer with step-wise LR decay, and the
        training/validation loops.
        """

        def __init__(self, csv_file, pretrained):
            super().__init__()
            self.model = PlantNet(4, pretrained=pretrained)
            self.csv_file = csv_file
            # Sentinel for the optional best-checkpoint logic in
            # validation_epoch_end below.
            self.best_loss = 10

        # Build datasets with augmentation and the train/valid index split.
        def prepare_data(self, valid_size=0.2, random_seed=42, shuffle=True):
            transforms = {
                'train': torchvision.transforms.Compose(
                    [
                        torchvision.transforms.RandomHorizontalFlip(),
                        torchvision.transforms.Resize((256, 256)),
                        torchvision.transforms.ToTensor(),
                        torchvision.transforms.Normalize(
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]
                        )
                    ]
                ),
                # BUG FIX: validation must be deterministic — the original
                # also applied RandomHorizontalFlip here, which adds random
                # noise to the validation metrics.
                'valid': torchvision.transforms.Compose(
                    [
                        torchvision.transforms.Resize((256, 256)),
                        torchvision.transforms.ToTensor(),
                        torchvision.transforms.Normalize(
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]
                        )
                    ]
                ),
            }

            self.train_dataset = PlantDataset(csv_file=self.csv_file,
                                              transforms=transforms['train'])
            self.valid_dataset = PlantDataset(csv_file=self.csv_file,
                                              transforms=transforms['valid'])

            num_train = len(self.train_dataset)
            indices = list(range(num_train))
            split = int(np.floor(valid_size * num_train))

            if shuffle:
                np.random.seed(random_seed)
                np.random.shuffle(indices)

            # Disjoint index sets over the same CSV implement the split.
            train_idx, valid_idx = indices[split:], indices[:split]
            self.train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
            self.valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_idx)

        def train_dataloader(self):
            return torch.utils.data.DataLoader(self.train_dataset,
                                               batch_size=config.BATCH_SIZE,
                                               sampler=self.train_sampler,
                                               num_workers=2)

        def val_dataloader(self):
            return torch.utils.data.DataLoader(self.valid_dataset,
                                               batch_size=config.BATCH_SIZE,
                                               sampler=self.valid_sampler,
                                               num_workers=2)

        def forward(self, x):
            return self.model(x)

        # Adam over the trainable (unfrozen) parameters only, with the LR
        # decayed by 10x after epochs 4, 6 and 8.
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=config.LR)
            stepLR = torch.optim.lr_scheduler.MultiStepLR(optimizer, [4, 6, 8], gamma=0.1)

            return [optimizer], [stepLR]

        # One optimization step; analogous to the body of
        # 'for step, (data, labels) in enumerate(dataloader)'.
        def training_step(self, batch, batch_idx):
            images, labels = batch

            outputs = self.forward(images)
            loss = torch.nn.functional.cross_entropy(outputs, labels)

            with torch.no_grad():
                acc = (torch.argmax(outputs, dim=1) == labels).float().mean()

            tqdm_dict = {'train_accuracy': acc}

            # 'loss' is mandatory; the other keys are optional extras.
            # NOTE: despite its name, 'num_correct' carries the per-batch
            # *mean* accuracy, not a count.
            output = OrderedDict({'loss': loss,
                                  'num_correct': acc,
                                  'log': tqdm_dict,
                                  'progress_bar': tqdm_dict})

            return output

        # Aggregate the per-batch results at the end of a training epoch.
        def training_epoch_end(self, outputs):
            """Compute and log training loss and accuracy at the epoch level."""

            train_loss_mean = torch.stack([output['loss']
                                           for output in outputs]).mean()
            # BUG FIX: 'num_correct' is a per-batch mean accuracy, not a
            # count, so average it directly.  The original summed the means
            # and divided by (num_batches * BATCH_SIZE), understating the
            # accuracy by a factor of BATCH_SIZE.
            train_acc_mean = torch.stack([output['num_correct']
                                          for output in outputs]).mean()
            return {'log': {'train_loss': train_loss_mean,
                            'train_acc': train_acc_mean,
                            'step': self.current_epoch}}

        # One validation step per batch.
        def validation_step(self, batch, batch_idx):
            images, labels = batch

            outputs = self.forward(images)
            loss = torch.nn.functional.cross_entropy(outputs, labels)
            acc = (torch.argmax(outputs, dim=1) == labels).float().mean()

            return {'val_loss': loss,
                    'num_correct': acc}

        # Aggregate the per-batch results at the end of a validation epoch.
        def validation_epoch_end(self, outputs):
            # Epoch-level means.  The original computed the loss mean twice
            # (avg_loss and val_loss_mean were identical); 'num_correct' is
            # already a per-batch mean, so sum()/len == mean().
            val_loss_mean = torch.stack([output['val_loss']
                                         for output in outputs]).mean()
            val_acc_mean = torch.stack([output['num_correct']
                                        for output in outputs]).mean()
            print(f'Validation Loss: {val_loss_mean}, Validation Accuracy: {val_acc_mean}')

    #         Save the best model's weights (ModelCheckpoint works too):
    #         if val_loss_mean < self.best_loss:
    #             self.best_loss = val_loss_mean
    #             torch.save({'best_loss': val_loss_mean, 'model': self.model,
    #                         'model_state_dict': self.model.state_dict()},
    #                         'best_model.pt')

            return {'log': {'val_loss': val_loss_mean,
                            'val_acc': val_acc_mean}}
    

    定义好LightningModule之后就可以进行训练了。类似keras中的model.fit,pytorch-lightning也是一步完成训练;同时在pl.Trainer中可以指定多gpu训练、分布式训练、callback等操作

    # Build the LightningModule and train it.  NOTE(review): this uses the
    # pytorch-lightning 0.7.x API — `max_nb_epochs` and `early_stop_callback`
    # were renamed/removed in later releases.
    model = PlantLightning(config.BASE_DIR + 'train.csv', pretrained=True)
    early_stopping = EarlyStopping('val_loss', patience=3)
    
    trainer = pl.Trainer(gpus=[0],max_nb_epochs=10, early_stop_callback=early_stopping)
    trainer.fit(model)
    
    #开始训练,首先打印模型的结构,随后进行一系列检查之后开始训练模型
    INFO:lightning:GPU available: True, used: True
    INFO:lightning:CUDA_VISIBLE_DEVICES: [0]
    INFO:lightning:
       | Name                              | Type              | Params
    --------------------------------------------------------------------
    0  | model                             | PlantNet          | 11 M  
    1  | model.model                       | ResNet            | 11 M  
    2  | model.model.conv1                 | Conv2d            | 9 K   
    3  | model.model.bn1                   | BatchNorm2d       | 128   
    4  | model.model.relu                  | ReLU              | 0     
    5  | model.model.maxpool               | MaxPool2d         | 0     
    6  | model.model.layer1                | Sequential        | 147 K 
    7  | model.model.layer1.0              | BasicBlock        | 73 K  
    8  | model.model.layer1.0.conv1        | Conv2d            | 36 K  
    9  | model.model.layer1.0.bn1          | BatchNorm2d       | 128   
    10 | model.model.layer1.0.relu         | ReLU              | 0     
    11 | model.model.layer1.0.conv2        | Conv2d            | 36 K  
    12 | model.model.layer1.0.bn2          | BatchNorm2d       | 128   
    13 | model.model.layer1.1              | BasicBlock        | 73 K  
    14 | model.model.layer1.1.conv1        | Conv2d            | 36 K  
    15 | model.model.layer1.1.bn1          | BatchNorm2d       | 128   
    16 | model.model.layer1.1.relu         | ReLU              | 0     
    17 | model.model.layer1.1.conv2        | Conv2d            | 36 K  
    18 | model.model.layer1.1.bn2          | BatchNorm2d       | 128   
    19 | model.model.layer2                | Sequential        | 525 K 
    20 | model.model.layer2.0              | BasicBlock        | 230 K 
    21 | model.model.layer2.0.conv1        | Conv2d            | 73 K  
    22 | model.model.layer2.0.bn1          | BatchNorm2d       | 256   
    23 | model.model.layer2.0.relu         | ReLU              | 0     
    24 | model.model.layer2.0.conv2        | Conv2d            | 147 K 
    25 | model.model.layer2.0.bn2          | BatchNorm2d       | 256   
    26 | model.model.layer2.0.downsample   | Sequential        | 8 K   
    27 | model.model.layer2.0.downsample.0 | Conv2d            | 8 K   
    28 | model.model.layer2.0.downsample.1 | BatchNorm2d       | 256   
    29 | model.model.layer2.1              | BasicBlock        | 295 K 
    30 | model.model.layer2.1.conv1        | Conv2d            | 147 K 
    31 | model.model.layer2.1.bn1          | BatchNorm2d       | 256   
    32 | model.model.layer2.1.relu         | ReLU              | 0     
    33 | model.model.layer2.1.conv2        | Conv2d            | 147 K 
    34 | model.model.layer2.1.bn2          | BatchNorm2d       | 256   
    35 | model.model.layer3                | Sequential        | 2 M   
    36 | model.model.layer3.0              | BasicBlock        | 919 K 
    37 | model.model.layer3.0.conv1        | Conv2d            | 294 K 
    38 | model.model.layer3.0.bn1          | BatchNorm2d       | 512   
    39 | model.model.layer3.0.relu         | ReLU              | 0     
    40 | model.model.layer3.0.conv2        | Conv2d            | 589 K 
    41 | model.model.layer3.0.bn2          | BatchNorm2d       | 512   
    42 | model.model.layer3.0.downsample   | Sequential        | 33 K  
    43 | model.model.layer3.0.downsample.0 | Conv2d            | 32 K  
    44 | model.model.layer3.0.downsample.1 | BatchNorm2d       | 512   
    45 | model.model.layer3.1              | BasicBlock        | 1 M   
    46 | model.model.layer3.1.conv1        | Conv2d            | 589 K 
    47 | model.model.layer3.1.bn1          | BatchNorm2d       | 512   
    48 | model.model.layer3.1.relu         | ReLU              | 0     
    49 | model.model.layer3.1.conv2        | Conv2d            | 589 K 
    50 | model.model.layer3.1.bn2          | BatchNorm2d       | 512   
    51 | model.model.layer4                | Sequential        | 8 M   
    52 | model.model.layer4.0              | BasicBlock        | 3 M   
    53 | model.model.layer4.0.conv1        | Conv2d            | 1 M   
    54 | model.model.layer4.0.bn1          | BatchNorm2d       | 1 K   
    55 | model.model.layer4.0.relu         | ReLU              | 0     
    56 | model.model.layer4.0.conv2        | Conv2d            | 2 M   
    57 | model.model.layer4.0.bn2          | BatchNorm2d       | 1 K   
    58 | model.model.layer4.0.downsample   | Sequential        | 132 K 
    59 | model.model.layer4.0.downsample.0 | Conv2d            | 131 K 
    60 | model.model.layer4.0.downsample.1 | BatchNorm2d       | 1 K   
    61 | model.model.layer4.1              | BasicBlock        | 4 M   
    62 | model.model.layer4.1.conv1        | Conv2d            | 2 M   
    63 | model.model.layer4.1.bn1          | BatchNorm2d       | 1 K   
    64 | model.model.layer4.1.relu         | ReLU              | 0     
    65 | model.model.layer4.1.conv2        | Conv2d            | 2 M   
    66 | model.model.layer4.1.bn2          | BatchNorm2d       | 1 K   
    67 | model.model.avgpool               | AdaptiveAvgPool2d | 0     
    68 | model.model.fc                    | Linear            | 2 K   
    
    在训练过程中显示进度条并打印loss和acc

    验证步骤

    # Inference: load the best checkpoint and predict the test set.
    test_df = pd.read_csv(config.BASE_DIR +'test.csv')
    
    print('Loading pre-trained model')
    model = PlantNet(4, False)
    model_ckpt = torch.load('best_model.pt')
    print(model.load_state_dict(model_ckpt['model_state_dict']))
    
    print('Testing!')
    
    test_dataset = TestDataset(config.BASE_DIR + 'test.csv', 
                               transforms=torchvision.transforms.Compose(
                                        [
                                            torchvision.transforms.Resize((256, 256)),
                                            torchvision.transforms.ToTensor()
                                        ]
                                    )
                            )
    
    # BUG FIX: shuffle must be False.  The original used shuffle=True, so
    # prediction order no longer matched test_df.image_id and every row of
    # the submission was assigned to the wrong image.
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
    
    # Hoisted out of the loop (the original moved the model to the GPU on
    # every batch), and eval() added so BatchNorm uses its running stats.
    model = model.to('cuda')
    model.eval()
    
    batch_preds = []
    with torch.no_grad():
        for images in tqdm(test_dataloader):
            images = images.to('cuda')
            preds = torch.nn.functional.softmax(model(images), 1)
            batch_preds.append(preds.cpu().numpy())
    
    # Collect per-batch arrays instead of the original dummy-zeros-row that
    # had to be dropped afterwards.
    predictions = np.concatenate(batch_preds, 0)
    
    output = pd.DataFrame(predictions, columns=['healthy', 'multiple_diseases', 'rust', 'scab'])
    output['image_id'] = test_df.image_id
    output = output[['image_id', 'healthy', 'multiple_diseases', 'rust', 'scab']]
    
    # output.to_csv('submission.csv', index=False)
    

    参考自https://www.kaggle.com/abhiswain/pytorch-lightning-resnet50-simple-baseline

    相关文章

      网友评论

          本文标题:pytorch-lightning baseline

          本文链接:https://www.haomeiwen.com/subject/oalpnhtx.html