美文网首页
一文搞定UNet——图像分割(语义分割)

一文搞定UNet——图像分割(语义分割)

作者: 小黄不头秃 | 来源:发表于2023-06-14 02:29 被阅读0次

UNet被广泛的应用于图像分割(语义分割的模型),Unet 发表于 2015 年,属于 FCN 的一种变体。可以用于摇杆卫星影像的分割,工业上瑕疵划痕检测等。接下来我们来仔细讨论一下这个网络,并给出基于pytorch的代码。

一、网络结构

UNet闻如其名,整个网络架构就像是一个U字母一样。图像经过下采样,进行特征提取,再经过上采样,输出相应尺寸的图片。如下图所示:

首先我们需要知道怎么看这张图片,蓝色的长方形方块代表数据大小,方块上的数字代表通道数,左边的数字代表输入的宽高。蓝色方块之间有一个蓝色向右的箭头表示卷积层(卷积核3*3,步长为1),两个卷积层组合成了一个卷积核块(Conv_Block),其作用为使得进行特征提取,通道数增加,宽高减4。

向下的红色箭头为下采样块(Down_Sample_Block),其中组成是一个最大池化层(卷积核3*3,步长为2)。输出结果通道数不变,图像宽高减半。

绿色向上的箭头为上采样块(Up_Sample_Block),这里可以使用反卷积操作,或者使用torch.nn.functional.interpolate函数进行上采样。使得图像通道数减小的同时,宽高*2。

灰色的向右箭头是指将左边的Conv_Block输出的中间变量,与右边上采样块的输出结果进行特征融合,其具体操作为把两个特征的通道堆叠在一起。

(1)卷积层块(Conv_Block)

卷积块中包含了两个卷积层,卷积核大小为3*3,步长设置为1。


"""
卷积块(Conv_Block):
    原论文中使用3*3的卷积核,stride为1。但是在实际的代码中,我们加入padding使得特征变量的宽高不会发生改变。
"""
class Conv_Block(nn.Module):
    def __init__(self, in_C, out_c) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_C, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
            nn.Conv2d(out_c, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
        ) 

    def forward(self, x):
        return self.layers(x)
(2)下采样块(Down_Sample_Block)

下采样块中采用了一层最大池化层,卷积核3*3,步长为2。输出结果通道数不变,图像宽高减半。

"""
下采样块(Down_Sample_Block):
    使用最大池化层对图像进行特征提取。
"""
class Down_Sample_Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(2,2)
        ) 

    def forward(self, x):
        return self.layers(x)
(3)上采样块(Up_Sample_Block)

上采样块包括了使用卷积层对特征向量通道进行减少,上采样是下采样的反操作,其具体的作用是使得特征变量增加。通常做法有两种:

  • 反卷积:使用torch.nn.ConvTranspose2d(3,3,2,2)函数可以实现。
  • 插值法:使用torch.nn.functional.interpolate函数进行上采样。或者也可以使用nn.Upsample(scale_factor=2, mode='bilinear')完成上采样的操作。
"""
上采样块(Up_Sample_Block):
    使用插值法对特征变量进行上采样,使得宽高翻倍,通道减半
"""
class Up_Sample_Block(nn.Module):
    def __init__(self, in_c) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_c, in_c//2, 3, 1, padding=1, padding_mode="reflect", bias=False),
            nn.BatchNorm2d(in_c//2),
            nn.ReLU()
        )
        # 上采样方法1:
        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2) 
        # 上采样方法2:
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self,x,feature):
        # 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
        x = self.layers(x)
        x = self.upsample1(x)
        # 下面两行代码是将
        # resize = Resize((x.shape[2], x.shape[3]))
        # feature = resize(feature)
        res = torch.cat((x,feature),dim=1) # 包括了通道融合
        return res
(4)通道融合

首先我们可能会问,为什么要做这样的通道融合呢?其实直接对特征变量进行上采样,其实是没有特征信息增加的。在我们上采样的时候,很多的细节特征都被丢失了。所以为了能够补充更多的特征信息,UNet将前面的中间变量拼接到后面上采样的结果中,使得特征更加丰富。其实对于通道融合是有两种做法的。

  • 第一种是类似于ResNet,将两个特征变量进行相加。使得原来的变量包含更多的信息。
  • 第二种是将特征变量在通道的维度上进行拼接,使用torch.cat((x1,x2), dim=1)使得特征信息的增加。UNet选用的是这种方案。
res = torch.cat((x,feature),dim=1)
(5)输出层

输出层是一个卷积层,卷积核大小为1*1,将图像输出为我们需要的样子。

"""
输出模块:
"""
class Output(nn.Module):
    def __init__(self,in_c, out_c) -> None:
        super().__init__()
        self.layers = self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c, 1, 1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

二、代码实现

(1)网络模型
  • model.py
import torch 
import torch.nn as nn 
from torchvision.transforms import Resize

"""
卷积块(Conv_Block):
    原论文中使用3*3的卷积核,stride为1。但是在实际的代码中,我们加入padding使得特征变量的宽高不会发生改变。
"""
class Conv_Block(nn.Module):
    def __init__(self, in_C, out_c) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_C, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
            nn.Conv2d(out_c, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
        ) 

    def forward(self, x):
        return self.layers(x)

"""
下采样块(Down_Sample_Block):
    使用最大池化层对图像进行特征提取。
"""
class Down_Sample_Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(2,2)
        ) 

    def forward(self, x):
        return self.layers(x)
"""
上采样块(Up_Sample_Block):
    使用插值法对特征变量进行上采样,使得宽高翻倍
"""
class Up_Sample_Block(nn.Module):
    def __init__(self, in_c) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_c, in_c//2, 3, 1, padding=1, padding_mode="reflect", bias=False),
            nn.BatchNorm2d(in_c//2),
            nn.ReLU()
        )
        # 上采样方法1:
        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2) 
        # 上采样方法2:
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self,x,feature):
        # 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
        x = self.layers(x)
        x = self.upsample1(x)
        # 下面两行代码是将
        # resize = Resize((x.shape[2], x.shape[3]))
        # feature = resize(feature)
        res = torch.cat((x,feature),dim=1)
        return res
    
"""
输出模块:
"""
class Output(nn.Module):
    def __init__(self,in_c, out_c) -> None:
        super().__init__()
        self.layers = self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c, 1, 1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

"""
UNet:
    # Unet 总体分为四个模块:
    - 一个卷积模块:包含两个卷积层,从输入通道到输出通道。(原论文中有进行尺寸减小2,但是在实际的应用中我们不进行宽高减小)
    - 一个下采样模块:maxpooling,宽高减半,加上一个卷积层。后接卷积模块。
    - 上采样:使用插值法将宽高增加,在将对应的下采样的特征图进行信息补充。
        # 可以通过类似残差模块的相加,也可以使用路由
    - 输出:1*1卷积:改变通道数量
"""
class UNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.down = Down_Sample_Block()

        self.conv1 = Conv_Block(1,64)
        self.conv2 = Conv_Block(64,128)
        self.conv3 = Conv_Block(128,256)
        self.conv4 = Conv_Block(256,512)
        self.conv5 = Conv_Block(512,1024)
        self.conv6 = Conv_Block(1024,512)
        self.conv7 = Conv_Block(512,256)
        self.conv8 = Conv_Block(256,128)
        self.conv9 = Conv_Block(128,64)
        self.up1 = Up_Sample_Block(1024)
        self.up2 = Up_Sample_Block(512)
        self.up3 = Up_Sample_Block(256)
        self.up4 = Up_Sample_Block(128)
        self.out = Output(64,1)

    def forward(self,x):
        out1 = self.conv1(x)
        out2 = self.conv2(self.down(out1))
        out3 = self.conv3(self.down(out2))
        out4 = self.conv4(self.down(out3)) # [1, 512, 28, 28]
        out5 = self.conv5(self.down(out4)) # [1, 1024, 14, 14]
        out6 = self.conv6(self.up1(out5,out4)) # [1, 512, 28, 28]
        out7 = self.conv7(self.up2(out6,out3)) # [1, 256, 56, 56]
        out8 = self.conv8(self.up3(out7,out2)) # [1, 128, 112, 112]
        out9 = self.conv9(self.up4(out8,out1)) # [1, 64, 224, 224]
        out = self.out(out9)
        return out

if __name__ == "__main__":
    x = torch.randn((1,1,224,224))
    net = UNet()
    y = net(x)
    print(y.shape)
    
(2)数据集:
(3)训练结果:
  • train.py:
import PIL.Image as Image
import numpy as np 
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import os 
from model import UNet
from torchvision.utils import save_image

IMG_PATH = "./data/FundusVessels/JPEGImages/"
TARGET_PATH = "./data/FundusVessels/Annotations/"
DST_DIR = "./img"

class EYE_Dataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.img_list =  os.listdir(IMG_PATH)
        self.target_list = os.listdir(TARGET_PATH)

    def __len__(self) -> int:
        return len(self.img_list)
    
    def __getitem__(self, index):
        name = self.img_list[index][0:-4]
        img = Image.open(IMG_PATH+f"{name}.jpg").convert("L").resize((224,224))
        lable = Image.open(TARGET_PATH+f"{name}.png").resize((224,224))
        img = np.array(img, dtype=np.float32)/255
        img = img[np.newaxis,:]
        lable = np.array(lable, dtype=np.float32)
        return img, lable
    


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_dataset = EYE_Dataset()
    train_loader = DataLoader(train_dataset,batch_size=1,shuffle=True)

    net = UNet().to(device)
    if os.path.exists("./UNet.pt"):
        params = torch.load("./UNet.pt")
        net.load_state_dict(params)

    optim = torch.optim.Adam(net.parameters())
    loss_fn = nn.BCELoss()
    epoch=1
    net.train()
    while True:
        for i,(img,target) in enumerate(train_loader):
            img, target = img.to(device), target.to(device)
            y = net(img)
            loss = loss_fn(y, target.unsqueeze(dim=0))
            optim.zero_grad()
            loss.backward()
            optim.step()

            if i%100 == 0:
                # 保存测试结果
                img2 =  (y[0]>0.6).float() *255
                res = torch.stack([img[0],img2],dim=0)
                save_image(res.cpu(), DST_DIR + f"/epoch{epoch}_{i}.jpg", nrow=2)
                print(f"epoch {epoch},loss: {loss.item()}")
                torch.save(net.state_dict(),"./UNet.pt")
            epoch += 1

大概训练了1050轮

相关文章

  • Unet图像分割

    Unet网络是一种图像语义分割网络,图像语义分割网络让计算机根据图像的语义来进行分割,例如让计算机在输入下面下图,...

  • 图像语义分割基础知识整理(CNN,FCN,Unet,CVPR-D

    整理的一些关于图像语义分割,CNN,FCN,Unet等的基础知识同时对CVPR-DeepGlobe路网分割竞赛部分...

  • 图像分割常用指标及MIoU计算

    一、语义图像分割结果定义 语义图像分割结果,可分为True Positive、False Positive、Tru...

  • 【转载】语义分割中的深度学习方法全解:从FCN、SegNet到各

    http://www.sohu.com/a/155907339_610300 图像语义分割就是机器自动从图像中分割...

  • 计算机视觉-语义分割

    什么是语义分割 我们都知道,图像是由许多像素(Pixel)组成,而「语义分割」顾名思义就是将像素按照图像中表达语义...

  • 我所了解的图像分割

    图像分割是我大二2019年做的东西,这篇文章用来总结。 一、什么是图像分割 分语义【像素级别图像】,实例【分割物体...

  • 语义图像分割概览

    摘要:本文讨论如何利用卷积神经网络进行语义图像分割的任务。 语义图像分割,目标是将图像的每个像素标记为所表示的相关...

  • 图像分割

    一、语义分割  图像语义分割的目标是标记图像每个像素的类别。因为我们需要预测图像中的每个像素,所以此任务通常被称为...

  • 自动驾驶入门日记-4-图像语义分割

    接着介绍图像语义分割中解决分割精细度的另一种解决方案--多尺度特征融合。在基于FCNs的语义分割架构中,如何有效的...

  • 基于深度学习的语义分割技术讲解

    在计算机视觉领域,有一个方向是语义分割,一般是针对图像进行像素级分类。具体而言,语义图像分割就是将每个像素都标注上...

网友评论

      本文标题:一文搞定UNet——图像分割(语义分割)

      本文链接:https://www.haomeiwen.com/subject/cnrpedtx.html