美文网首页
Pytorch之图像分割(多目标分割,Multi Object Segmentation)

Pytorch之图像分割(多目标分割,Multi Object Segmentation)

作者: 深思海数_willschang | 来源:发表于2021-09-07 15:42 被阅读0次

    示例调用预训练模型(deeplabv3_resnet101)对VOCSegmentation数据进行图像分割实验。

    • PyTorch的DeepLabv3-ResNet101语义分割模型是在COCO 2017训练集的一个子集上训练得到的,该子集只包含PASCAL VOC数据集中的20个类别(加上背景共21类,对应下文的num_classes=21)。
    • Deeplabv3-ResNet101由具有ResNet-101主干的Deeplabv3模型构成。

    引入相关包

    %matplotlib inline
    import os
    import copy
    import numpy as np
    from skimage.segmentation import mark_boundaries
    import matplotlib.pylab as plt
    from PIL import Image   
    
    import torch
    from torch import nn
    from torch import optim
    from torchvision.datasets import VOCSegmentation
    from torchvision.transforms.functional import to_tensor, to_pil_image
    from torch.utils.data import DataLoader
    from torchvision.models.segmentation import deeplabv3_resnet101
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    

    构建数据 dataset

    class DemoVOCSegmentation(VOCSegmentation):
        """VOC segmentation dataset whose joint transform uses the albumentations calling style.

        ``self.transforms`` (when set) is called as ``transforms(image=..., mask=...)`` on
        numpy arrays, and the result is read back from the ``'image'`` / ``'mask'`` keys.
        """

        def __getitem__(self, index):
            """Return ``(image_tensor, mask_tensor)`` for sample ``index``.

            The image goes through ``to_tensor``. The mask becomes an int64 tensor in which
            every label greater than 20 is remapped to 0 (background). Labels above 20 are
            presumably the VOC void/boundary value 255.
            """
            img = Image.open(self.images[index]).convert('RGB')
            # Convert the mask to numpy up front so that torch.from_numpy below always gets an
            # ndarray. Previously a PIL image reached it when no transform was configured.
            target = np.array(Image.open(self.masks[index]))

            if self.transforms is not None:
                augmented = self.transforms(image=np.array(img), mask=target)
                img = augmented['image']
                target = augmented['mask']

            # Only class ids 0..20 are valid. Apply the remap in both branches, not only
            # after augmentation.
            target[target > 20] = 0

            img = to_tensor(img)
            target = torch.from_numpy(target).type(torch.long)
            return img, target
        
        
    from albumentations import (
        HorizontalFlip,
        Compose,
        Resize,
        Normalize)
    
    # Standard ImageNet channel mean / std. They are reused as the defaults of re_normalize.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # Every image/mask pair is resized to this fixed height and width.
    h = w = 520

    # Training pipeline: resize, random horizontal flip, then normalize.
    transform_train = Compose([
        Resize(h, w),
        HorizontalFlip(p=0.5),
        Normalize(mean=mean, std=std),
    ])

    # Validation pipeline: deterministic, so no flip.
    transform_val = Compose([
        Resize(h, w),
        Normalize(mean=mean, std=std),
    ])
    
     # Dataset root directory. download=False, so the VOC files must already be on disk here.
    path_data = "./data/mos/"

    # Create the training dataset
    train_ds = DemoVOCSegmentation(path_data,
                    year='2012',
                    image_set='train',
                    download=False,
                    transforms=transform_train)
    print(len(train_ds))
    # 1464

    # Create the validation dataset (no random flip in transform_val)
    val_ds = DemoVOCSegmentation(path_data,
                    year='2012',
                    image_set='val',
                    download=False,
                    transforms=transform_val)
    print(len(val_ds))  # 1449
    
    • 数据查看(可视化)
    np.random.seed(0)
    num_classes = 21
    # One RGB colour per class (plus a spare row), as floats in [0, 1]. mark_boundaries draws
    # on a float image. The previous randint(0, 2) allowed only 0 or 1 per channel, i.e. 8
    # possible colours for 22 rows, so many classes shared an outline colour.
    COLORS = np.random.rand(num_classes + 1, 3)
    
    def show_img_target(img, target):
        """Draw the boundary of every class region in ``target`` over ``img`` and show it.

        If ``img`` is a tensor, it is converted to a PIL image and ``target`` is converted with
        ``.numpy()``. Otherwise both are used as given. Class ``k`` is outlined with ``COLORS[k]``.
        """
        if torch.is_tensor(img):
            img = to_pil_image(img)
            target = target.numpy()

        overlay = img
        for class_id in range(num_classes):
            class_mask = target == class_id
            colour = COLORS[class_id]
            # mark_boundaries returns a new image array; feed it back in for the next class.
            overlay = mark_boundaries(
                np.array(overlay),
                class_mask,
                outline_color=colour,
                color=colour,
            )
        plt.imshow(overlay)
        
    
    def re_normalize(x, mean=mean, std=std):
        """Return a copy of ``x`` with per-channel normalization undone.

        Channel ``c`` (for each ``c`` covered by ``zip(mean, std)``) becomes
        ``x[c] * std[c] + mean[c]``. The input tensor is not modified.
        """
        restored = x.clone()
        for idx, (s, m) in enumerate(zip(std, mean)):
            # restored[idx] is a view, so the in-place ops write into ``restored``.
            restored[idx].mul_(s).add_(m)
        return restored
    
    
    
    # Inspect one training sample: tensor stats, then image / mask / overlay side by side.
    img, mask = train_ds[6]
    for sample_tensor in (img, mask):
        print(sample_tensor.shape, sample_tensor.type(), torch.max(sample_tensor))

    plt.figure(figsize=(20, 20))

    img_r = re_normalize(img)
    # Panel 1: de-normalized input image.
    plt.subplot(1, 3, 1)
    plt.imshow(to_pil_image(img_r))
    # Panel 2: raw label mask.
    plt.subplot(1, 3, 2)
    plt.imshow(mask)
    # Panel 3: class boundaries drawn on the image.
    plt.subplot(1, 3, 3)
    show_img_target(img_r, mask)
    """
    torch.Size([3, 520, 520]) torch.FloatTensor tensor(2.6400)
    torch.Size([520, 520]) torch.LongTensor tensor(4)
    """
    
    image segmentation

    数据加载器及加载模型

    # dataloader
    train_dl = DataLoader(train_ds, batch_size=2, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=8, shuffle=False)
    # 加载预训练模型
    model=deeplabv3_resnet101(pretrained=True, num_classes=21)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model=model.to(device)
    # print(model)
    

    模型部署

    # Run inference on the first validation batch only.
    model.eval()
    with torch.no_grad():
        for xb, yb in val_dl:
            # The model returns a dict; its "out" entry holds the per-class scores.
            yb_pred = model(xb.to(device))
            yb_pred = yb_pred["out"].cpu()
            print(yb_pred.shape)
            # Predicted class index per pixel (argmax over the class dimension).
            yb_pred = torch.argmax(yb_pred, dim=1)
            break
    print(yb_pred.shape)

    plt.figure(figsize=(20, 20))

    # Index of the sample within the batch to visualize (must be < the val batch size, 8).
    n = 4
    img, mask = xb[n], yb_pred[n]
    img_r = re_normalize(img)
    plt.subplot(1, 3, 1)
    plt.imshow(to_pil_image(img_r))

    plt.subplot(1, 3, 2)
    plt.imshow(mask)

    plt.subplot(1, 3, 3)
    show_img_target(img_r, mask)
    # Expected printed shapes with val_dl batch_size=8 (the earlier note showed a stale 16):
    # torch.Size([8, 21, 520, 520])
    # torch.Size([8, 520, 520])
    
    deploy model to predict

    模型训练(由于电脑显卡配置较低,未能实际运行微调训练进行测试)

    def get_lr(opt):
        """Return the learning rate of the optimizer's first parameter group.

        Returns None when ``opt.param_groups`` is empty.
        """
        return next((group['lr'] for group in opt.param_groups), None)
    
    def loss_batch(loss_func, output, target, opt=None):
        """Compute ``loss_func(output, target)``; if ``opt`` is given, also take one step.

        With an optimizer, gradients are zeroed, the loss is back-propagated and
        ``opt.step()`` is called.

        Returns:
            ``(loss as a Python float, None)``. The second slot is always None here; callers
            store it as the metric.
        """
        loss = loss_func(output, target)
        if opt is None:
            return loss.item(), None

        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss.item(), None
    
    # One pass over a data loader (used for both training and validation).
    def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
        """Run ``model`` over ``dataset_dl`` and return the accumulated loss per sample.

        Args:
            model: segmentation model whose forward output is indexed with ``["out"]``.
            loss_func: criterion applied to (``out``, targets) for each batch.
            dataset_dl: iterable of ``(inputs, targets)`` batches; both are moved to the
                module-level ``device``.
            sanity_check: if truthy, stop after the first batch (quick smoke run).
            opt: optimizer. When given, ``loss_batch`` back-propagates and steps on each batch.

        Returns:
            ``(sum of per-batch losses / number of samples processed, None)``. With a
            ``reduction="sum"`` criterion this is a per-sample average. Dividing by the
            samples actually seen, rather than ``len(dataset)``, keeps a sanity-check run
            (one batch) from being diluted by the full dataset size. For a full pass the
            two counts are equal.

        Raises:
            ZeroDivisionError: if the loader yields no batches.
        """
        running_loss = 0.0
        num_samples = 0

        for xb, yb in dataset_dl:
            xb = xb.to(device)
            yb = yb.to(device)

            output = model(xb)["out"]
            loss_b, _ = loss_batch(loss_func, output, yb, opt)
            running_loss += loss_b
            num_samples += len(xb)

            if sanity_check:
                break

        loss = running_loss / float(num_samples)
        return loss, None
    
    def train_val(model, params):
        """Train ``model`` with one validation pass per epoch and keep the best weights.

        ``params`` must provide the keys ``num_epochs``, ``loss_func``, ``optimizer``,
        ``train_dl``, ``val_dl``, ``sanity_check``, ``lr_scheduler`` and ``path2weights``.

        Per epoch:
          * run a training pass with the optimizer and a validation pass under ``torch.no_grad()``;
          * when the validation loss improves, copy the weights in memory and save them to
            ``path2weights``;
          * step the scheduler with the validation loss; if the learning rate changes,
            reload the best weights seen so far.

        Returns:
            ``(model, loss_history, metric_history)``. The model has the best weights loaded.
            Each history is ``{"train": [...], "val": [...]}`` with one entry per epoch.
        """
        keys = ("num_epochs", "loss_func", "optimizer", "train_dl", "val_dl",
                "sanity_check", "lr_scheduler", "path2weights")
        (num_epochs, loss_func, opt, train_dl, val_dl,
         sanity_check, lr_scheduler, path2weights) = (params[k] for k in keys)

        loss_history = {phase: [] for phase in ("train", "val")}
        metric_history = {phase: [] for phase in ("train", "val")}

        best_model_wts = copy.deepcopy(model.state_dict())
        best_loss = float('inf')

        for epoch in range(num_epochs):
            current_lr = get_lr(opt)
            print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))

            # Training pass: the optimizer is passed, so weights are updated.
            model.train()
            train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
            loss_history["train"].append(train_loss)
            metric_history["train"].append(train_metric)

            # Validation pass: no optimizer and no gradient tracking.
            model.eval()
            with torch.no_grad():
                val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
            loss_history["val"].append(val_loss)
            metric_history["val"].append(val_metric)

            if val_loss < best_loss:
                best_loss = val_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), path2weights)
                print("Copied best model weights!")

            lr_scheduler.step(val_loss)
            lr_changed = current_lr != get_lr(opt)
            if lr_changed:
                # After an LR change, resume from the best weights found so far.
                print("Loading best model weights!")
                model.load_state_dict(best_model_wts)

            print("train loss: %.6f" % (train_loss))
            print("val loss: %.6f" % (val_loss))
            print("-" * 10)

        model.load_state_dict(best_model_wts)
        return model, loss_history, metric_history
    
    • 训练模型
    # Cross-entropy summed over all pixels of the batch (loss_epoch then divides by a sample count).
    criterion = nn.CrossEntropyLoss(reduction="sum")
    opt = optim.Adam(model.parameters(), lr=1e-6)
    # Halve the LR when the validation loss has stopped improving for `patience` (20) epochs.
    # NOTE(review): recent PyTorch versions deprecate the `verbose` argument.
    lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)

    path2models = "./models/mos/"
    # makedirs also creates the missing parent "./models". os.mkdir would raise
    # FileNotFoundError when the parent does not exist yet.
    os.makedirs(path2models, exist_ok=True)

    params_train = {
        "num_epochs": 10,
        "optimizer": opt,
        "loss_func": criterion,
        "train_dl": train_dl,
        "val_dl": val_dl,
        "sanity_check": True,  # one batch per pass: a quick smoke run, not real training
        "lr_scheduler": lr_scheduler,
        "path2weights": path2models + "sanity_weights.pt",
    }

    model, loss_hist, _ = train_val(model, params_train)
    
    • 可视化结果
    # Plot the per-epoch train and validation loss curves.
    num_epochs = params_train["num_epochs"]
    epochs_axis = range(1, num_epochs + 1)

    plt.title("Train-Val Loss")
    for phase in ("train", "val"):
        plt.plot(epochs_axis, loss_hist[phase], label=phase)
    plt.ylabel("Loss")
    plt.xlabel("Training Epochs")
    plt.legend()
    plt.show()
    
    image.png

    相关文章

      网友评论

          本文标题:Pytorch之图像分割(多目标分割,Multi Object

          本文链接:https://www.haomeiwen.com/subject/ggcpwltx.html