美文网首页
PyTorch深度学习简明实战20 - 语义分割 - LinkNet

PyTorch深度学习简明实战20 - 语义分割 - LinkNet

作者: 薛东弗斯 | 来源:发表于2023-04-04 08:05 被阅读0次

ResNet在图像分割领域的应用 - LinkNet

image.png
image.png
image.png
[1707.03718] LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation (arxiv.org)
1707.03718.pdf (arxiv.org)
image.png
图中的 full-conv 实际上是反卷积(转置卷积)。
初始卷积:7x7卷积核,输入3个特征通道,64个卷积核,/2 表示图像缩放为原来的1/2。
max-pool:3x3池化窗口,/2 表示图像缩放为原来的1/2。
输出部分的第一个 full-conv(反卷积):3x3卷积核,输入64个特征,输出32个特征,*2 表示图像放大为之前的2倍。
中间的 conv:3x3卷积核,输入32个特征,输出32个特征,不对图片进行缩放,只做特征提取。
最后的 full-conv(反卷积):2x2卷积核,输入32个特征,N个卷积核(N的取值取决于要分为多少个类别),*2 表示再将图片放大为之前的2倍。
image.png
编码器模块(encoder block)由两个残差模块组成。
conv[(3x3),(m,n),/2] 3x3卷积核,输入m个特征,n个卷积核,图像缩放为原来的一半 image.png
image.png
image.png
image.png
image.png
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torchvision
from torchvision import transforms
import os

import glob
from PIL import Image

BATCH_SIZE = 128

# Look at one original image:
# pil_img = Image.open('./data/hk/training/00001.png')
# np_img = np.array(pil_img)
# plt.imshow(np_img)
# plt.show()

# Look at its segmentation matte:
# pil_img = Image.open('./data/hk/training/00001_matte.png')
# np_img = np.array(pil_img)
# plt.imshow(np_img)
# plt.show()

# np_img.max(), np_img.min()   # (255, 0)
# np_img.shape                 # (800, 600)
# np.unique(np_img)            # array([  0, ..., 255]) -- matte values span 0-255, not a binary 0/1 mask

# Binarize the matte (every pixel > 0 becomes 1) and plot it:
# pil_img = Image.open('./data/hk/training/00001_matte.png')
# np_img = np.array(pil_img)
# np_img[np_img > 0] = 1
# plt.imshow(np_img)
# plt.show()
# np.unique(np_img)            # array([0, 1], dtype=uint8) -- only 0 and 1 remain; some detail of the matte is lost.


def _image_matte_pairs(pattern):
    """Return ``(image_paths, matte_paths)`` for the PNG files matching ``pattern``.

    Every file whose name does not contain ``'matte'`` is an input image; its
    annotation is the sibling ``<name>_matte.png``. Deriving the matte path from the
    image path keeps the two lists index-aligned by construction (a missing matte
    then fails loudly when the dataset opens it, instead of silently shifting every
    later pair). The listing is sorted because ``glob`` returns files in
    filesystem-dependent order, which would otherwise make the seeded shuffle below
    non-reproducible across machines.
    """
    paths = sorted(glob.glob(pattern))
    image_paths = [p for p in paths if 'matte' not in os.path.basename(p)]
    matte_paths = [os.path.splitext(p)[0] + '_matte.png' for p in image_paths]
    return image_paths, matte_paths


# Training split (the author's run found 1700 image/matte pairs).
images, annotations = _image_matte_pairs('./data/hk/training/*.png')

# Shuffle with a fixed seed; the same permutation is applied to both lists so pairs stay aligned.
np.random.seed(2021)
index = np.random.permutation(len(images))
images = np.array(images)[index]
anno = np.array(annotations)[index]

# Test split, in sorted order.
test_images, test_anno = _image_matte_pairs('./data/hk/testing/*.png')

# Preprocessing shared by images and masks: resize to a fixed 256x256,
# then convert the PIL image into a tensor.
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

class Portrait_dataset(data.Dataset):
    """Portrait segmentation dataset yielding ``(image, binary mask)`` tensor pairs.

    Args:
        img_paths: sequence of paths to the input images.
        anno_paths: sequence of paths to the matching matte (annotation) images,
            index-aligned with ``img_paths``.
    """

    def __init__(self, img_paths, anno_paths):
        self.imgs = img_paths
        self.annos = anno_paths

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img_tensor, anno_tensor)``.

        ``img_tensor`` is the module-level ``transform`` applied to the image.
        ``anno_tensor`` is a ``torch.long`` mask with singleton dims squeezed out,
        1 wherever the transformed matte is > 0 and 0 elsewhere.
        """
        # `with` closes the file handles; `transform` has fully read the pixels by then.
        with Image.open(self.imgs[index]) as pil_img:
            img_tensor = transform(pil_img)

        with Image.open(self.annos[index]) as pil_anno:
            anno_tensor = transform(pil_anno)

        # ToTensor scales the 0-255 matte to floats in [0, 1]. Threshold *before*
        # casting to long: casting first truncates every value below 1.0 to 0, so only
        # pure-255 pixels would survive as foreground instead of every pixel > 0.
        anno_tensor = (torch.squeeze(anno_tensor) > 0).long()
        return img_tensor, anno_tensor

    def __len__(self):
        return len(self.imgs)
    
# Wrap the path lists in datasets: the (shuffled) training pairs and the test pairs.
train_dataset = Portrait_dataset(images, anno)
test_dataset = Portrait_dataset(test_images, test_anno)

# Only the training loader reshuffles every epoch.
train_dl = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dl = data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Pull one training batch as a sanity check.
imgs_batch, annos_batch = next(iter(train_dl))
# imgs_batch.shape    # (batch, 3, 256, 256), batch <= BATCH_SIZE
# annos_batch.shape   # (batch, 256, 256) -- the dataset squeezes away the mask's channel dim
#
# Plot the first image next to its mask. permute() moves channels last for imshow;
# the mask has no channel dim, so it is plotted as-is.
# img = imgs_batch[0].permute(1, 2, 0).numpy()
# anno = annos_batch[0].numpy()
# plt.subplot(1, 2, 1)   # left panel: image
# plt.imshow(img)
# plt.subplot(1, 2, 2)   # right panel: mask
# plt.imshow(anno)

# 创建LinkNet模型
# 1. 编写卷积模块(卷积 + BN + activate)
# 2. 编写反卷积模块(反卷积+BN+activate)
# 3. 编码器(4个卷积模块)
# 4. 解码器(卷积模块+反卷积模块+卷积模块)
# 5. 实现整体网络结构

# 1. 卷积模块
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (convolution filters).
        k_size: square kernel size.
        stride: convolution stride; the downsampling factor (2 roughly halves H and W).
        pad: zero padding on each side (1 keeps H and W for a 3x3 kernel at stride 1).
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, pad=1):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=k_size, stride=stride, padding=pad),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),  # overwrite the BN output instead of allocating a copy
        ]
        self.conv_relu = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_relu(x)

# 2. 反卷积模块
class DeconvBlock(nn.Module):
    """Transposed convolution, optionally followed by BatchNorm2d + ReLU.

    Note the parameter naming: ``pad`` is the transposed conv's ``output_padding``,
    while ``padding`` is its ordinary ``padding``. With the defaults (k=3, stride=2,
    padding=1, output_padding=1) the output is exactly twice the input size.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        k_size: square kernel size.
        stride: upsampling stride.
        pad: ``output_padding`` of the transposed convolution.
        padding: ``padding`` of the transposed convolution.
    """

    def __init__(self, in_channels, out_channels,
                 k_size=3, stride=2, pad=1, padding=1):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=k_size, stride=stride,
            padding=padding, output_padding=pad,
        )
        # BN + ReLU are applied in forward only when is_act is True.
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_act=True):
        """Upsample ``x``; when ``is_act`` is False return the raw transposed-conv output."""
        upsampled = self.deconv(x)
        return torch.relu(self.bn(upsampled)) if is_act else upsampled

# 3.编码器
class EncodeBlock(nn.Module):
    """LinkNet encoder block: two residual units, the first of which halves H and W.

    Unit 1: conv1_1 (3x3, m -> n, stride 2) -> conv1_2 (3x3, n -> n), added to a
    stride-2 ``shortcut`` projection of the input so the shapes match.
    Unit 2: conv2_1 -> conv2_2 (3x3, n -> n), added to the output of unit 1
    (identity skip).

    Args:
        in_channels: channels of the input feature map (m).
        out_channels: channels of the output feature map (n).
    """

    def __init__(self, in_channels, out_channels):
        super(EncodeBlock, self).__init__()
        self.conv1_1 = ConvBlock(in_channels, out_channels, stride=2)  # m -> n, downsample by 2
        self.conv1_2 = ConvBlock(out_channels, out_channels)           # n -> n, same size
        self.conv2_1 = ConvBlock(out_channels, out_channels)
        self.conv2_2 = ConvBlock(out_channels, out_channels)
        self.shortcut = ConvBlock(in_channels, out_channels, stride=2)  # projects x to unit 1's shape

    def forward(self, x):
        # Residual unit 1: each of the four convs is used exactly once, in order.
        out1 = self.conv1_2(self.conv1_1(x))
        out1 = out1 + self.shortcut(x)
        # Residual unit 2: its skip carries the full output of unit 1
        # (conv branch plus shortcut), not just the conv branch.
        out2 = self.conv2_2(self.conv2_1(out1))
        return out2 + out1
    
# 4. 解码器
class DecodeBlock(nn.Module):
    """LinkNet decoder block: 1x1 conv (m -> m/4), transposed conv (m/4 -> m/4,
    2x upsampling with DeconvBlock's defaults), 1x1 conv (m/4 -> n).

    Args:
        in_channels: channels of the input feature map (m).
        out_channels: channels of the output feature map (n).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        bottleneck = in_channels // 4
        self.conv1 = ConvBlock(in_channels, bottleneck, k_size=1, pad=0)
        self.deconv = DeconvBlock(bottleneck, bottleneck)
        self.conv2 = ConvBlock(bottleneck, out_channels, k_size=1, pad=0)

    def forward(self, x):
        return self.conv2(self.deconv(self.conv1(x)))
    
class Net(nn.Module):
    """LinkNet segmentation network.

    Structure:
    - Initial block: 7x7 conv /2 followed by a 3x3 max-pool /2.
    - Encoder: four encoder blocks, each /2.
    - Decoder: four decoder blocks, each x2. Each block's output is added to the
      matching encoder feature map.
    - Head: x2 transposed conv -> 3x3 conv -> x2 transposed conv, producing raw
      per-pixel class logits.

    The encoder downsamples by 64 in total and the skip additions require every
    stage to halve exactly, so H and W should be multiples of 64 (256x256 here).

    Args:
        num_classes: number of output channels / classes. Defaults to 2
            (background vs. portrait), the value previously hard-coded.
    """

    def __init__(self, num_classes=2):
        super(Net, self).__init__()
        self.init_conv = ConvBlock(3, 64,
                                   k_size=7,
                                   stride=2,
                                   pad=3)
        # 3x3 window, stride 2 as in the LinkNet initial block; padding=1 gives
        # the same output size as the previous 2x2 pool for even inputs.
        self.init_maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.encode1 = EncodeBlock(64, 64)
        self.encode2 = EncodeBlock(64, 128)
        self.encode3 = EncodeBlock(128, 256)
        self.encode4 = EncodeBlock(256, 512)

        self.decode4 = DecodeBlock(512, 256)
        self.decode3 = DecodeBlock(256, 128)
        self.decode2 = DecodeBlock(128, 64)
        self.decode1 = DecodeBlock(64, 64)

        self.deconv_last1 = DeconvBlock(64, 32)
        self.conv_last = ConvBlock(32, 32)
        self.deconv_last2 = DeconvBlock(32, num_classes,
                                        k_size=2,
                                        pad=0,
                                        padding=0)

    def forward(self, x):
        # Shapes are (N, C, H, W) for a (N, 3, 256, 256) input.
        x = self.init_conv(x)              # (N, 64, 128, 128)
        x = self.init_maxpool(x)           # (N, 64, 64, 64)

        e1 = self.encode1(x)               # (N, 64, 32, 32)
        e2 = self.encode2(e1)              # (N, 128, 16, 16)
        e3 = self.encode3(e2)              # (N, 256, 8, 8)
        e4 = self.encode4(e3)              # (N, 512, 4, 4)

        d4 = self.decode4(e4) + e3         # (N, 256, 8, 8)
        d3 = self.decode3(d4) + e2         # (N, 128, 16, 16)
        d2 = self.decode2(d3) + e1         # (N, 64, 32, 32)
        d1 = self.decode1(d2)              # (N, 64, 64, 64)

        f1 = self.deconv_last1(d1)         # (N, 32, 128, 128)
        f2 = self.conv_last(f1)            # (N, 32, 128, 128)
        # No BN/ReLU on the last layer: CrossEntropyLoss expects raw logits.
        f3 = self.deconv_last2(f2, is_act=False)  # (N, num_classes, 256, 256)

        return f3
    
    
    
from torch.optim import lr_scheduler

# Build the network and place it on the GPU when one is available.
model = Net()
if torch.cuda.is_available():
    model.to('cuda')

# Cross-entropy between the per-pixel class logits and the integer masks.
loss_fn = nn.CrossEntropyLoss()

# Adam at 1e-3; StepLR multiplies the LR by 0.1 every 7 scheduler steps
# (fit() steps it once per epoch).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

def fit(epoch, model, trainloader, testloader):
    """Run one training epoch on ``trainloader``, then evaluate on ``testloader``.

    Uses the module-level ``loss_fn``, ``optimizer`` and ``exp_lr_scheduler``; the
    scheduler is stepped once per call. Batches are moved to ``'cuda'`` when CUDA is
    available.

    Args:
        epoch: epoch number, used only in the printed log line.
        model: the segmentation network, called on each image batch.
        trainloader: DataLoader of ``(images, masks)`` training batches.
        testloader: DataLoader of ``(images, masks)`` evaluation batches.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)``. Losses are
        averaged over the dataset (each batch's mean loss weighted by its batch size);
        accuracies are the fraction of correctly classified pixels.
    """
    correct = 0
    total = 0
    running_loss = 0

    model.train()
    for x, y in trainloader:
        if torch.cuda.is_available():
            x, y = x.to('cuda'), y.to('cuda')
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            # Count pixels directly instead of assuming 256x256 masks.
            total += y.numel()
            # `loss` is a batch mean; weight it by the batch size so that dividing by
            # the dataset size below gives a true average (not one ~BATCH_SIZE too small).
            running_loss += loss.item() * y.size(0)
    exp_lr_scheduler.step()
    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / total

    test_correct = 0
    test_total = 0
    test_running_loss = 0

    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            if torch.cuda.is_available():
                x, y = x.to('cuda'), y.to('cuda')
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.numel()
            test_running_loss += loss.item() * y.size(0)

    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / test_total

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
          )

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

epochs = 40

# Per-epoch metric histories, filled in by the loop below.
train_loss, train_acc = [], []
test_loss, test_acc = [], []

for epoch in range(epochs):
    (epoch_loss, epoch_acc,
     epoch_test_loss, epoch_test_acc) = fit(epoch, model, train_dl, test_dl)
    # Record this epoch's metrics (in-place list extension).
    train_loss += [epoch_loss]
    train_acc += [epoch_acc]
    test_loss += [epoch_test_loss]
    test_acc += [epoch_test_acc]

# 保存模型
# PATH = 'unet_model.pth'
# torch.save(model.state_dict(), PATH)

# 测试模型
# my_model = Net()
# my_model.load_state_dict(torch.load(PATH))
# num=3  # 取出3张图片

# image, mask = next(iter(test_dl))
# pred_mask = my_model(image)

# plt.figure(figsize=(10, 10))
# for i in range(num):
#     plt.subplot(num, 3, i*num+1)   # i从0开始, 第一行 第1张图片的原图
#     plt.imshow(image[i].permute(1,2,0).cpu().numpy())
#     plt.subplot(num, 3, i*num+2)   # 实际的分割图
#     plt.imshow(mask[i].cpu().numpy())
#     plt.subplot(num, 3, i*num+3)   # 预测出的分割图 
#     plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy())        # detach 取出实际结果
    
# train数据集上测试
# image, mask = next(iter(train_dl))
# pred_mask = my_model(image)

# plt.figure(figsize=(10, 10))
# for i in range(num):
#     plt.subplot(num, 3, i*num+1)
#     plt.imshow(image[i].permute(1,2,0).cpu().numpy())
#     plt.subplot(num, 3, i*num+2)
#     plt.imshow(mask[i].cpu().numpy())
#     plt.subplot(num, 3, i*num+3)
#     plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy())

加入 IoU(Intersection over Union,交并比)指标

def _batch_iou(target, pred):
    """Return the foreground IoU of two same-shaped label tensors as a Python float.

    Non-zero entries count as foreground. When neither tensor contains any foreground,
    the prediction agrees with the target everywhere, so 1.0 is returned instead of
    the 0/0 NaN the plain ratio would produce.
    """
    intersection = torch.logical_and(target, pred).sum().item()
    union = torch.logical_or(target, pred).sum().item()
    return intersection / union if union else 1.0


def fit(epoch, model, trainloader, testloader):
    """Run one training epoch, evaluate on ``testloader``, and report pixel IoU.

    Uses the module-level ``loss_fn``, ``optimizer`` and ``exp_lr_scheduler``; the
    scheduler is stepped once per call. Batches are moved to ``'cuda'`` when CUDA is
    available. Prints loss, pixel accuracy and the mean per-batch foreground IoU for
    both splits.

    Args:
        epoch: epoch number, used only in the printed log lines.
        model: the segmentation network, called on each image batch.
        trainloader: DataLoader of ``(images, masks)`` training batches.
        testloader: DataLoader of ``(images, masks)`` evaluation batches.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)``. Losses are
        averaged over the dataset (each batch's mean loss weighted by its batch size);
        accuracies are the fraction of correctly classified pixels. IoU is printed only.
    """
    correct = 0
    total = 0
    running_loss = 0
    epoch_iou = []

    model.train()
    for x, y in trainloader:
        # Keep each batch on the same device as the model, which this script
        # moves to CUDA when it is available.
        if torch.cuda.is_available():
            x, y = x.to('cuda'), y.to('cuda')
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.numel()                       # count pixels, any mask size
            running_loss += loss.item() * y.size(0)  # batch-mean loss -> batch sum
            # Store plain floats so np.mean works regardless of the tensor's device.
            epoch_iou.append(_batch_iou(y, y_pred))

    exp_lr_scheduler.step()
    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / total

    test_correct = 0
    test_total = 0
    test_running_loss = 0
    epoch_test_iou = []

    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            if torch.cuda.is_available():
                x, y = x.to('cuda'), y.to('cuda')
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.numel()
            test_running_loss += loss.item() * y.size(0)
            epoch_test_iou.append(_batch_iou(y, y_pred))

    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / test_total

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'IOU:', round(np.mean(epoch_iou), 3))
    print()
    print('     ', 'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3),
          'test_iou:', round(np.mean(epoch_test_iou), 3)
          )

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

epochs = 40

# If the earlier 40-epoch run in this script has already executed, it has called
# exp_lr_scheduler.step() 40 times (StepLR, step_size=7, gamma=0.1). The learning
# rate has then decayed from 1e-3 to 1e-8, and training with it would barely change
# the weights. fit() reads the module-level optimizer and exp_lr_scheduler, so
# rebuild the model, optimizer and scheduler to give this IoU-reporting run a fresh
# start.
model = Net()
if torch.cuda.is_available():
    model.to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
                                                                 model,
                                                                 train_dl,
                                                                 test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
# Save the trained weights.
PATH = 'linknet_model.pth'
torch.save(model.state_dict(), PATH)


def _show_predictions(net, loader, num):
    """Plot ``num`` rows of (input image, ground-truth mask, predicted mask).

    Takes the first batch of ``loader``, keeps its first ``num`` samples, runs ``net``
    on them without tracking gradients, and draws a ``num`` x 3 grid of subplots.
    Returns the ``(image, mask, pred_mask)`` tensors that were plotted.
    """
    image, mask = next(iter(loader))
    image, mask = image[:num], mask[:num]   # only the plotted samples need a forward pass
    with torch.no_grad():
        pred_mask = net(image)

    plt.figure(figsize=(10, 10))
    for i in range(num):
        # Each row has 3 columns, so row i starts at subplot index 3*i + 1.
        plt.subplot(num, 3, i * 3 + 1)   # input image, channels moved last for imshow
        plt.imshow(image[i].permute(1, 2, 0).cpu().numpy())
        plt.subplot(num, 3, i * 3 + 2)   # ground-truth mask
        plt.imshow(mask[i].cpu().numpy())
        plt.subplot(num, 3, i * 3 + 3)   # predicted mask: argmax over the class dim
        plt.imshow(torch.argmax(pred_mask[i], dim=0).cpu().numpy())
    return image, mask, pred_mask


# Test the model: load the saved weights into a fresh CPU network. Without the
# load, Net() is randomly initialised and the plots would show untrained output.
my_model = Net()
my_model.load_state_dict(torch.load(PATH, map_location='cpu'))
my_model.eval()   # use BatchNorm running statistics for inference
num = 3

# Predictions on the test set.
image, mask, pred_mask = _show_predictions(my_model, test_dl, num)

# Predictions on the training set.
image, mask, pred_mask = _show_predictions(my_model, train_dl, num)

相关文章

网友评论

      本文标题:PyTorch深度学习简明实战20 - 语义分割 - LinkNet

      本文链接:https://www.haomeiwen.com/subject/ejajddtx.html