美文网首页
PyTorch深度学习简明实战20 - 语义分割 - LinkNet

PyTorch深度学习简明实战20 - 语义分割 - LinkNet

作者: 薛东弗斯 | 来源:发表于2023-04-04 08:05 被阅读0次

    Resnet在图像分割领域的应用- Linknet

    image.png
    image.png
    image.png
    [1707.03718] LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation (arxiv.org)
    1707.03718.pdf (arxiv.org)
    image.png
    输出的full-conv 实际上是反卷积
    输入7x7的卷积核,输入特征3,64个卷积核,/2 表示图像缩放为原来的1/2
    max-pool 3x3卷积核,/2表示对图像缩放为1/2
    输出full-conv 反卷积,3x3卷积核,输入64个特征,输出32个特征,*2 表示图像放大为之前的2倍
    中间的conv 用3x3卷积核,输入32 输出32,没有对图片进行缩放,只是一种特征的提取
    最后输出full-conv反卷积,2x2 kernel,输入特征32,N个filter,N的取值取决于要分为多少类别,然后 *2 对图片进行放大
    image.png
    输入 两个残差模块。
    conv[(3x3),(m,n),/2] 3x3卷积核,输入m个特征,n个卷积核,图像缩放为原来的一半 image.png
    image.png
    image.png
    image.png
    image.png
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils import data
    
    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    import torchvision
    from torchvision import transforms
    import os
    
    import glob
    from PIL import Image
    
    BATCH_SIZE = 128
    
    # 绘制原图
    # pil_img = Image.open('./data/hk/training/00001.png')
    # np_img = np.array(pil_img)
    # plt.imshow(np_img)
    # plt.show()
    
    # 绘制分割后的图
    # pil_img = Image.open('./data/hk/training/00001_matte.png')
    # np_img = np.array(pil_img)
    # plt.imshow(np_img)
    # plt.show()
    
    # np_img.max(), np_img.min()   # (255, 0)
    # np_img.shape    # (800, 600)
    # np.unique(np_img)    # array([  0, ..., 255])  像素值在0-255之间,不是二分类所需的0或1
    
    # 绘制像素点为0/1的图片
    # pil_img = Image.open('./data/hk/training/00001_matte.png')
    # np_img = np.array(pil_img)
    # np_img[np_img>0]=1
    # plt.imshow(np_img)
    # plt.show()
    # np.unique(np_img)    # array([0, 1], dtype=uint8)   此时,像素只包括0和1. 这种变换对原有像素有一定的损失。
    
    # Collect every PNG of the training folder. Source images are named
    # "NNNNN.png" and their masks "NNNNN_matte.png" (see the sample below).
    all_pics = glob.glob('./data/hk/training/*.png')
    # all_pics[:5]
    # ['./data/hk/training\\00001.png',
    #  './data/hk/training\\00001_matte.png',
    #  './data/hk/training\\00002.png',
    #  './data/hk/training\\00002_matte.png',
    #  './data/hk/training\\00003.png']

    # glob() gives no ordering guarantee, so both lists are sorted on the same
    # key (the path with "_matte" removed). That way images[i] and
    # annotations[i] always refer to the same sample.
    images = sorted(p for p in all_pics if 'matte' not in p)
    # len(images)  # 1700
    annotations = sorted((p for p in all_pics if 'matte' in p),
                         key=lambda p: p.replace('_matte', ''))
    # len(annotations)  # 1700

    # Shuffle images and masks with one shared, seeded permutation so the
    # pairing survives and the order is reproducible.
    np.random.seed(2021)
    index = np.random.permutation(len(images))
    images = np.array(images)[index]
    anno = np.array(annotations)[index]

    # Same pairing logic for the held-out test folder (kept unshuffled).
    all_test_pics = glob.glob('./data/hk/testing/*.png')
    test_images = sorted(p for p in all_test_pics if 'matte' not in p)
    test_anno = sorted((p for p in all_test_pics if 'matte' in p),
                       key=lambda p: p.replace('_matte', ''))
    
    # Preprocessing shared by images and masks: resize to 256x256, then turn
    # the PIL image into a tensor.
    transform = transforms.Compose(
        [transforms.Resize((256, 256)), transforms.ToTensor()]
    )
    
    class Portrait_dataset(data.Dataset):
        """Dataset of (image, binary mask) tensor pairs for portrait segmentation.

        Args:
            img_paths: Indexable collection of source-image paths.
            anno_paths: Indexable collection of mask ("matte") paths, in the
                same order as ``img_paths``.
        """

        def __init__(self, img_paths, anno_paths):
            self.imgs = img_paths
            self.annos = anno_paths

        def __getitem__(self, index):
            """Load the sample at ``index`` and convert it to tensors.

            Both files are passed through the module-level ``transform``.

            Returns:
                ``(img_tensor, anno_tensor)`` where ``anno_tensor`` has dtype
                ``torch.long`` and holds 1 wherever the transformed mask is
                greater than 0, and 0 elsewhere.
            """
            pil_img = Image.open(self.imgs[index])
            img_tensor = transform(pil_img)

            pil_anno = Image.open(self.annos[index])
            # squeeze() drops size-1 axes; presumably the mask is single
            # channel so this removes the channel axis -- TODO confirm.
            anno_tensor = torch.squeeze(transform(pil_anno))
            # Binarize BEFORE the integer cast. The transformed mask holds
            # fractional values (ToTensor scales 8-bit images to [0, 1]), so
            # casting first would truncate every value below 1 to 0 and the
            # "> 0 becomes 1" rule would silently lose those pixels.
            anno_tensor = (anno_tensor > 0).type(torch.long)

            return img_tensor, anno_tensor

        def __len__(self):
            return len(self.imgs)
        
    # Datasets built from the shuffled training paths and the test paths.
    train_dataset = Portrait_dataset(img_paths=images, anno_paths=anno)
    test_dataset = Portrait_dataset(img_paths=test_images, anno_paths=test_anno)

    # Only the training loader shuffles its samples.
    train_dl = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_dl = data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # Fetch a single batch for inspection.
    imgs_batch, annos_batch = next(iter(train_dl))
    # imgs_batch.shape     # presumably (BATCH_SIZE, 3, 256, 256) for RGB inputs -- TODO confirm
    # annos_batch.shape    # presumably (BATCH_SIZE, 256, 256); torch.squeeze in the dataset removes size-1 axes
    
    # img = imgs_batch[0].permute(1,2,0).numpy()    # 对第1张图片进行绘图。  permute将channel放到最后面
    # anno = annos_batch[0].numpy()                 # anno图片没有channle这个属性,因此不需要用permute
    
    # plt.subplot(1,2,1)                            # 绘制1行2列的第1张图
    # plt.imshow(img)
    # plt.subplot(1,2,2)                            # 绘制1行2列的第2张图
    # plt.imshow(anno)                     
    
    # 创建LinkNet模型
    # 1. 编写卷积模块(卷积 + BN + activate)
    # 2. 编写反卷积模块(反卷积+BN+activate)
    # 3. 编码器(4个卷积模块)
    # 4. 解码器(卷积模块+反卷积模块+卷积模块)
    # 5. 实现整体网络结构
    
    # 1. 卷积模块
    class ConvBlock(nn.Module):
        """Conv2d followed by BatchNorm2d and an in-place ReLU.

        Args:
            in_channels: Channels of the input feature map.
            out_channels: Channels produced by the convolution.
            k_size: Convolution kernel size (default 3).
            stride: Convolution stride; this is what controls downscaling
                (default 1).
            pad: Zero padding applied by the convolution (default 1).
        """

        def __init__(self, in_channels, out_channels, k_size=3, stride=1, pad=1):
            super().__init__()
            conv = nn.Conv2d(in_channels, out_channels,
                             kernel_size=k_size, stride=stride, padding=pad)
            norm = nn.BatchNorm2d(out_channels)
            # inplace=True overwrites the input buffer instead of allocating.
            act = nn.ReLU(inplace=True)
            self.conv_relu = nn.Sequential(conv, norm, act)

        def forward(self, x):
            """Return ``x`` passed through conv -> batch norm -> ReLU."""
            return self.conv_relu(x)
    
    # 2. 反卷积模块
    class DeconvBlock(nn.Module):
        """ConvTranspose2d, optionally followed by BatchNorm2d and ReLU.

        Args:
            in_channels: Channels of the input feature map.
            out_channels: Channels produced by the transposed convolution.
            k_size: Kernel size (default 3).
            stride: Stride of the transposed convolution (default 2).
            pad: Forwarded as ``output_padding`` -- note that, despite its
                name, this is NOT the input padding (default 1).
            padding: Forwarded as ``padding`` (default 1).
        """

        def __init__(self, in_channels, out_channels,
                     k_size=3, stride=2, pad=1, padding=1):
            super().__init__()
            self.deconv = nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=k_size,
                stride=stride,
                padding=padding,
                output_padding=pad,
            )
            # Kept separate from the deconv so forward() can skip it.
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x, is_act=True):
            """Upsample ``x``.

            When ``is_act`` is False the raw transposed-convolution output is
            returned; otherwise batch norm and ReLU are applied on top.
            """
            upsampled = self.deconv(x)
            if not is_act:
                return upsampled
            return torch.relu(self.bn(upsampled))
    
    # 3.编码器
    class EncodeBlock(nn.Module):
        """Encoder stage: two 2-conv units with a strided shortcut.

        The first conv and the shortcut both use stride 2, so they downscale
        the input in the same way; the remaining convs keep the size.

        Args:
            in_channels: Channels of the input feature map.
            out_channels: Channels of every feature map produced here.
        """

        def __init__(self, in_channels, out_channels):
            super(EncodeBlock, self).__init__()
            self.conv1_1 = ConvBlock(in_channels, out_channels, stride=2)
            self.conv1_2 = ConvBlock(out_channels, out_channels)
            self.conv2_1 = ConvBlock(out_channels, out_channels)
            self.conv2_2 = ConvBlock(out_channels, out_channels)
            # Strided projection of the input so it can be added to out1.
            self.shortcut = ConvBlock(in_channels, out_channels, stride=2)

        def forward(self, x):
            """Return the encoded feature map for ``x``."""
            out1 = self.conv1_1(x)
            # Second conv of the first unit. This previously called conv2_1,
            # which left conv1_2 unused and applied conv2_1 twice.
            out1 = self.conv1_2(out1)
            residue = self.shortcut(x)
            out2 = self.conv2_1(out1 + residue)
            out2 = self.conv2_2(out2)
            return out2 + out1
        
    # 4. 解码器
    class DecodeBlock(nn.Module):
        """Decoder stage: 1x1 conv -> transposed conv -> 1x1 conv.

        The first 1x1 conv reduces the channels to ``in_channels // 4``, the
        DeconvBlock (default settings) keeps that width, and the last 1x1
        conv maps to ``out_channels``.

        Args:
            in_channels: Channels of the input feature map.
            out_channels: Channels of the output feature map.
        """

        def __init__(self, in_channels, out_channels):
            super().__init__()
            mid_channels = in_channels // 4
            self.conv1 = ConvBlock(in_channels, mid_channels, k_size=1, pad=0)
            self.deconv = DeconvBlock(mid_channels, mid_channels)
            self.conv2 = ConvBlock(mid_channels, out_channels, k_size=1, pad=0)

        def forward(self, x):
            """Return ``x`` after the three sub-blocks, applied in order."""
            squeezed = self.conv1(x)
            upsampled = self.deconv(squeezed)
            return self.conv2(upsampled)
        
    class Net(nn.Module):
        """LinkNet segmentation network: stem, 4 encoders, 4 decoders, head.

        The outputs of encoders 1-3 are added to the outputs of decoders 2-4
        (the skip links). The head upsamples twice more and returns raw,
        un-normalized class scores.

        Args:
            num_classes: Output channels of the final transposed convolution,
                i.e. one score map per class. Defaults to 2, the value that
                used to be hard-coded.
        """

        def __init__(self, num_classes=2):
            super(Net, self).__init__()
            # Stem: 7x7 stride-2 conv followed by a 2x2 max-pool.
            self.init_conv = ConvBlock(3, 64,
                                       k_size=7,
                                       stride=2,
                                       pad=3)
            self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))

            self.encode1 = EncodeBlock(64, 64)
            self.encode2 = EncodeBlock(64, 128)
            self.encode3 = EncodeBlock(128, 256)
            self.encode4 = EncodeBlock(256, 512)

            self.decode4 = DecodeBlock(512, 256)
            self.decode3 = DecodeBlock(256, 128)
            self.decode2 = DecodeBlock(128, 64)
            self.decode1 = DecodeBlock(64, 64)

            # Head: deconv -> conv -> deconv down to num_classes channels.
            self.deconv_last1 = DeconvBlock(64, 32)
            self.conv_last = ConvBlock(32, 32)
            self.deconv_last2 = DeconvBlock(32, num_classes,
                                            k_size=2,
                                            pad=0,
                                            padding=0)

        def forward(self, x):
            """Return per-pixel class scores for the image batch ``x``.

            Shape comments use (N, C, H, W) and assume a 3x256x256 input.
            """
            x = self.init_conv(x)              # (N, 64, 128, 128)
            x = self.init_maxpool(x)           # (N, 64, 64, 64)

            e1 = self.encode1(x)               # (N, 64, 32, 32)
            e2 = self.encode2(e1)              # (N, 128, 16, 16)
            e3 = self.encode3(e2)              # (N, 256, 8, 8)
            e4 = self.encode4(e3)              # (N, 512, 4, 4)

            # Each decoder output is summed with the encoder output of the
            # same size.
            d4 = self.decode4(e4) + e3         # (N, 256, 8, 8)
            d3 = self.decode3(d4) + e2         # (N, 128, 16, 16)
            d2 = self.decode2(d3) + e1         # (N, 64, 32, 32)
            d1 = self.decode1(d2)              # (N, 64, 64, 64)

            f1 = self.deconv_last1(d1)         # (N, 32, 128, 128)
            f2 = self.conv_last(f1)            # (N, 32, 128, 128)
            # is_act=False: return raw scores, without batch norm / ReLU.
            f3 = self.deconv_last2(f2, is_act=False)   # (N, num_classes, 256, 256)

            return f3
        
        
        
    # Build the network and move it to the GPU when one is available.
    model = Net()
    if torch.cuda.is_available():
        model = model.to('cuda')

    # Loss applied to the network output and the integer mask.
    loss_fn = nn.CrossEntropyLoss()

    from torch.optim import lr_scheduler
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
    # Multiply the learning rate by 0.1 every 7 scheduler steps.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.1)
    
    def fit(epoch, model, trainloader, testloader):
        """Train ``model`` for one epoch, then evaluate it on ``testloader``.

        Uses the module-level ``loss_fn``, ``optimizer`` and
        ``exp_lr_scheduler``.

        Args:
            epoch: Epoch number, only used in the printed summary.
            model: Network to train; batches are moved to its device.
            trainloader: Iterable of ``(x, y)`` training batches.
            testloader: Iterable of ``(x, y)`` evaluation batches.

        Returns:
            ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)``:
            losses are averaged per sample, accuracies are the fraction of
            correctly classified pixels.
        """
        # Follow the model's device instead of hard-coding 'cuda'.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        correct = 0
        total_pixels = 0
        total_samples = 0
        running_loss = 0.0

        model.train()
        for x, y in trainloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                y_pred = torch.argmax(y_pred, dim=1)
                correct += (y_pred == y).sum().item()
                # Count real label elements rather than assuming 256*256.
                total_pixels += y.numel()
                total_samples += y.size(0)
                # Assumes loss_fn returns a batch mean; weight it by the
                # batch size so the epoch value is a true per-sample average
                # (dividing a sum of batch means by the dataset size
                # under-reports the loss by roughly the batch size).
                running_loss += loss.item() * y.size(0)
        exp_lr_scheduler.step()
        epoch_loss = running_loss / total_samples
        epoch_acc = correct / total_pixels

        test_correct = 0
        test_total_pixels = 0
        test_total_samples = 0
        test_running_loss = 0.0

        model.eval()
        with torch.no_grad():
            for x, y in testloader:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                loss = loss_fn(y_pred, y)
                y_pred = torch.argmax(y_pred, dim=1)
                test_correct += (y_pred == y).sum().item()
                test_total_pixels += y.numel()
                test_total_samples += y.size(0)
                test_running_loss += loss.item() * y.size(0)

        epoch_test_loss = test_running_loss / test_total_samples
        epoch_test_acc = test_correct / test_total_pixels

        print('epoch: ', epoch,
              'loss: ', round(epoch_loss, 3),
              'accuracy:', round(epoch_acc, 3),
              'test_loss: ', round(epoch_test_loss, 3),
              'test_accuracy:', round(epoch_test_acc, 3)
                 )

        return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
    
    epochs = 40

    # Per-epoch metric histories.
    train_loss, train_acc = [], []
    test_loss, test_acc = [], []

    for epoch in range(epochs):
        (epoch_loss, epoch_acc,
         epoch_test_loss, epoch_test_acc) = fit(epoch, model, train_dl, test_dl)
        # Record this epoch's four metrics.
        train_loss += [epoch_loss]
        train_acc += [epoch_acc]
        test_loss += [epoch_test_loss]
        test_acc += [epoch_test_acc]
    
    # 保存模型
    # PATH = 'unet_model.pth'
    # torch.save(model.state_dict(), PATH)
    
    # 测试模型
    # my_model = Net()
    # my_model.load_state_dict(torch.load(PATH))
    # num=3  # 取出3张图片
    
    # image, mask = next(iter(test_dl))
    # pred_mask = my_model(image)
    
    # plt.figure(figsize=(10, 10))
    # for i in range(num):
    #     plt.subplot(num, 3, i*num+1)   # i从0开始, 第一行 第1张图片的原图
    #     plt.imshow(image[i].permute(1,2,0).cpu().numpy())
    #     plt.subplot(num, 3, i*num+2)   # 实际的分割图
    #     plt.imshow(mask[i].cpu().numpy())
    #     plt.subplot(num, 3, i*num+3)   # 预测出的分割图 
    #     plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy())        # detach 取出实际结果
        
    # train数据集上测试
    # image, mask = next(iter(train_dl))
    # pred_mask = my_model(image)
    
    # plt.figure(figsize=(10, 10))
    # for i in range(num):
    #     plt.subplot(num, 3, i*num+1)
    #     plt.imshow(image[i].permute(1,2,0).cpu().numpy())
    #     plt.subplot(num, 3, i*num+2)
    #     plt.imshow(mask[i].cpu().numpy())
    #     plt.subplot(num, 3, i*num+3)
    #     plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy())
    
    

    IOU(交并比,Intersection over Union):下面的 fit 函数在原有指标之外增加了 IoU 的计算

    def fit(epoch, model, trainloader, testloader):
        """Train one epoch, evaluate, and also report the mean batch IoU.

        Uses the module-level ``loss_fn``, ``optimizer`` and
        ``exp_lr_scheduler``.

        Args:
            epoch: Epoch number, only used in the printed summary.
            model: Network to train; batches are moved to its device.
            trainloader: Iterable of ``(x, y)`` training batches.
            testloader: Iterable of ``(x, y)`` evaluation batches.

        Returns:
            ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)``:
            losses are averaged per sample, accuracies are the fraction of
            correctly classified pixels. IoU is printed, not returned.
        """
        # Batches must live on the model's device. The transfer used to be
        # commented out, which fails once the model has been moved to CUDA.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        correct = 0
        total_pixels = 0
        total_samples = 0
        running_loss = 0.0
        epoch_iou = []

        model.train()
        for x, y in trainloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                y_pred = torch.argmax(y_pred, dim=1)
                correct += (y_pred == y).sum().item()
                # Count real label elements rather than assuming 256*256.
                total_pixels += y.numel()
                total_samples += y.size(0)
                # Assumes loss_fn returns a batch mean; weight by batch size
                # to obtain a true per-sample average for the epoch.
                running_loss += loss.item() * y.size(0)

                # IoU over non-zero pixels, stored as plain floats so the
                # average works for CPU and CUDA tensors alike.
                intersection = torch.logical_and(y, y_pred).sum().item()
                union = torch.logical_or(y, y_pred).sum().item()
                if union > 0:      # IoU is undefined for an empty union
                    epoch_iou.append(intersection / union)

        exp_lr_scheduler.step()
        epoch_loss = running_loss / total_samples
        epoch_acc = correct / total_pixels

        test_correct = 0
        test_total_pixels = 0
        test_total_samples = 0
        test_running_loss = 0.0
        epoch_test_iou = []

        model.eval()
        with torch.no_grad():
            for x, y in testloader:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                loss = loss_fn(y_pred, y)
                y_pred = torch.argmax(y_pred, dim=1)
                test_correct += (y_pred == y).sum().item()
                test_total_pixels += y.numel()
                test_total_samples += y.size(0)
                test_running_loss += loss.item() * y.size(0)
                intersection = torch.logical_and(y, y_pred).sum().item()
                union = torch.logical_or(y, y_pred).sum().item()
                if union > 0:
                    epoch_test_iou.append(intersection / union)

        epoch_test_loss = test_running_loss / test_total_samples
        epoch_test_acc = test_correct / test_total_pixels

        # NaN signals that no batch had a non-empty union.
        mean_iou = sum(epoch_iou) / len(epoch_iou) if epoch_iou else float('nan')
        mean_test_iou = (sum(epoch_test_iou) / len(epoch_test_iou)
                         if epoch_test_iou else float('nan'))

        print('epoch: ', epoch,
              'loss: ', round(epoch_loss, 3),
              'accuracy:', round(epoch_acc, 3),
              'IOU:', round(mean_iou, 3))
        print()
        print('     ', 'test_loss: ', round(epoch_test_loss, 3),
              'test_accuracy:', round(epoch_test_acc, 3),
               'test_iou:', round(mean_test_iou, 3)
                 )

        return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
    
    epochs = 40

    # Fresh per-epoch metric histories for this training run.
    train_loss, train_acc = [], []
    test_loss, test_acc = [], []

    for epoch in range(epochs):
        (epoch_loss, epoch_acc,
         epoch_test_loss, epoch_test_acc) = fit(epoch, model, train_dl, test_dl)
        # Record this epoch's four metrics.
        train_loss += [epoch_loss]
        train_acc += [epoch_acc]
        test_loss += [epoch_test_loss]
        test_acc += [epoch_test_acc]
    
    # Save the trained weights.
    PATH = 'linknet_model.pth'
    torch.save(model.state_dict(), PATH)

    # Visual check on the test set. A fresh CPU network is created and the
    # saved weights are loaded into it -- without load_state_dict() the plots
    # would show the output of a randomly initialized model.
    my_model = Net()
    my_model.load_state_dict(torch.load(PATH, map_location='cpu'))
    my_model.eval()          # use the running batch-norm statistics
    num = 3                  # number of samples (rows) to plot
    image, mask = next(iter(test_dl))
    with torch.no_grad():    # inference only, no graph needed
        pred_mask = my_model(image)

    plt.figure(figsize=(10, 10))
    for i in range(num):
        # The grid has 3 columns, so row i starts at index i*3 + 1
        # (i*num + 1 was only correct while num happened to be 3).
        plt.subplot(num, 3, i*3+1)   # original image
        plt.imshow(image[i].permute(1,2,0).cpu().numpy())
        plt.subplot(num, 3, i*3+2)   # ground-truth mask
        plt.imshow(mask[i].cpu().numpy())
        plt.subplot(num, 3, i*3+3)   # predicted mask (arg-max over classes)
        plt.imshow(torch.argmax(pred_mask[i], dim=0).cpu().numpy())

    # Same check on one training batch.
    image, mask = next(iter(train_dl))
    with torch.no_grad():
        pred_mask = my_model(image)

    plt.figure(figsize=(10, 10))
    for i in range(num):
        plt.subplot(num, 3, i*3+1)
        plt.imshow(image[i].permute(1,2,0).cpu().numpy())
        plt.subplot(num, 3, i*3+2)
        plt.imshow(mask[i].cpu().numpy())
        plt.subplot(num, 3, i*3+3)
        plt.imshow(torch.argmax(pred_mask[i], dim=0).cpu().numpy())
    

    相关文章

      网友评论

          本文标题:PyTorch深度学习简明实战20 - 语义分割 - LinkNet

          本文链接:https://www.haomeiwen.com/subject/ejajddtx.html