美文网首页
一文搞定UNet——图像分割(语义分割)

一文搞定UNet——图像分割(语义分割)

作者: 小黄不头秃 | 来源:发表于2023-06-14 02:29 被阅读0次

    UNet被广泛的应用于图像分割(语义分割的模型),Unet 发表于 2015 年,属于 FCN 的一种变体。可以用于摇杆卫星影像的分割,工业上瑕疵划痕检测等。接下来我们来仔细讨论一下这个网络,并给出基于pytorch的代码。

    一、网络结构

    UNet闻如其名,整个网络架构就像是一个U字母一样。图像经过下采样,进行特征提取,再经过上采样,输出相应尺寸的图片。如下图所示:

    首先我们需要知道怎么看这张图片,蓝色的长方形方块代表数据大小,方块上的数字代表通道数,左边的数字代表输入的宽高。蓝色方块之间有一个蓝色向右的箭头表示卷积层(卷积核3*3,步长为1),两个卷积层组合成了一个卷积核块(Conv_Block),其作用为使得进行特征提取,通道数增加,宽高减4。

    向下的红色箭头为下采样块(Down_Sample_Block),其中组成是一个最大池化层(卷积核3*3,步长为2)。输出结果通道数不变,图像宽高减半。

    绿色向上的箭头为上采样块(Up_Sample_Block),这里可以使用反卷积操作,或者使用torch.nn.functional.interpolate函数进行上采样。使得图像通道数减小的同时,宽高*2。

    灰色的向右箭头是指将左边的Conv_Block输出的中间变量,与右边上采样块的输出结果进行特征融合,其具体操作为把两个特征的通道堆叠在一起。

    (1)卷积层块(Conv_Block)

    卷积块中包含了两个卷积层,卷积核大小为3*3,步长设置为1。


    """
    卷积块(Conv_Block):
        原论文中使用3*3的卷积核,stride为1。但是在实际的代码中,我们加入padding使得特征变量的宽高不会发生改变。
    """
    class Conv_Block(nn.Module):
        def __init__(self, in_C, out_c) -> None:
            super().__init__()
            self.layers = nn.Sequential(
                nn.Conv2d(in_C, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
                nn.Conv2d(out_c, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
            ) 
    
        def forward(self, x):
            return self.layers(x)
    
    (2)下采样块(Down_Sample_Block)

    下采样块中采用了一层最大池化层,卷积核3*3,步长为2。输出结果通道数不变,图像宽高减半。

    """
    下采样块(Down_Sample_Block):
        使用最大池化层对图像进行特征提取。
    """
    class Down_Sample_Block(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.layers = nn.Sequential(
                nn.MaxPool2d(2,2)
            ) 
    
        def forward(self, x):
            return self.layers(x)
    
    (3)上采样块(Up_Sample_Block)

    上采样块包括了使用卷积层对特征向量通道进行减少,上采样是下采样的反操作,其具体的作用是使得特征变量增加。通常做法有两种:

    • 反卷积:使用torch.nn.ConvTranspose2d(3,3,2,2)函数可以实现。
    • 插值法:使用torch.nn.functional.interpolate函数进行上采样。或者也可以使用nn.Upsample(scale_factor=2, mode='bilinear')完成上采样的操作。
    """
    上采样块(Up_Sample_Block):
        使用插值法对特征变量进行上采样,使得宽高翻倍,通道减半
    """
    class Up_Sample_Block(nn.Module):
        def __init__(self, in_c) -> None:
            super().__init__()
            self.layers = nn.Sequential(
                nn.Conv2d(in_c, in_c//2, 3, 1, padding=1, padding_mode="reflect", bias=False),
                nn.BatchNorm2d(in_c//2),
                nn.ReLU()
            )
            # 上采样方法1:
            self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2) 
            # 上采样方法2:
            self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear')
    
        def forward(self,x,feature):
            # 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
            x = self.layers(x)
            x = self.upsample1(x)
            # 下面两行代码是将
            # resize = Resize((x.shape[2], x.shape[3]))
            # feature = resize(feature)
            res = torch.cat((x,feature),dim=1) # 包括了通道融合
            return res
    
    (4)通道融合

    首先我们可能会问,为什么要做这样的通道融合呢?其实直接对特征变量进行上采样,其实是没有特征信息增加的。在我们上采样的时候,很多的细节特征都被丢失了。所以为了能够补充更多的特征信息,UNet将前面的中间变量拼接到后面上采样的结果中,使得特征更加丰富。其实对于通道融合是有两种做法的。

    • 第一种是类似于ResNet,将两个特征变量进行相加。使得原来的变量包含更多的信息。
    • 第二种是将特征变量在通道的维度上进行拼接,使用torch.cat((x1,x2), dim=1)使得特征信息的增加。UNet选用的是这种方案。
    res = torch.cat((x,feature),dim=1)
    
    (5)输出层

    输出层是一个卷积层,卷积核大小为1*1,将图像输出为我们需要的样子。

    """
    输出模块:
    """
    class Output(nn.Module):
        def __init__(self,in_c, out_c) -> None:
            super().__init__()
            self.layers = self.layers = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, 1, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(),
                nn.Sigmoid()
            )
    
        def forward(self, x):
            return self.layers(x)
    

    二、代码实现

    (1)网络模型
    • model.py
    import torch 
    import torch.nn as nn 
    from torchvision.transforms import Resize
    
    """
    卷积块(Conv_Block):
        原论文中使用3*3的卷积核,stride为1。但是在实际的代码中,我们加入padding使得特征变量的宽高不会发生改变。
    """
    class Conv_Block(nn.Module):
        def __init__(self, in_C, out_c) -> None:
            super().__init__()
            self.layers = nn.Sequential(
                nn.Conv2d(in_C, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
                nn.Conv2d(out_c, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
            ) 
    
        def forward(self, x):
            return self.layers(x)
    
    """
    下采样块(Down_Sample_Block):
        使用最大池化层对图像进行特征提取。
    """
    class Down_Sample_Block(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.layers = nn.Sequential(
                nn.MaxPool2d(2,2)
            ) 
    
        def forward(self, x):
            return self.layers(x)
    """
    上采样块(Up_Sample_Block):
        使用插值法对特征变量进行上采样,使得宽高翻倍
    """
    class Up_Sample_Block(nn.Module):
        def __init__(self, in_c) -> None:
            super().__init__()
            self.layers = nn.Sequential(
                nn.Conv2d(in_c, in_c//2, 3, 1, padding=1, padding_mode="reflect", bias=False),
                nn.BatchNorm2d(in_c//2),
                nn.ReLU()
            )
            # 上采样方法1:
            self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2) 
            # 上采样方法2:
            self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear')
    
        def forward(self,x,feature):
            # 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
            x = self.layers(x)
            x = self.upsample1(x)
            # 下面两行代码是将
            # resize = Resize((x.shape[2], x.shape[3]))
            # feature = resize(feature)
            res = torch.cat((x,feature),dim=1)
            return res
        
    """
    输出模块:
    """
    class Output(nn.Module):
        def __init__(self,in_c, out_c) -> None:
            super().__init__()
            self.layers = self.layers = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, 1, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(),
                nn.Sigmoid()
            )
    
        def forward(self, x):
            return self.layers(x)
    
    """
    UNet:
        # Unet 总体分为四个模块:
        - 一个卷积模块:包含两个卷积层,从输入通道到输出通道。(原论文中有进行尺寸减小2,但是在实际的应用中我们不进行宽高减小)
        - 一个下采样模块:maxpooling,宽高减半,加上一个卷积层。后接卷积模块。
        - 上采样:使用插值法将宽高增加,在将对应的下采样的特征图进行信息补充。
            # 可以通过类似残差模块的相加,也可以使用路由
        - 输出:1*1卷积:改变通道数量
    """
    class UNet(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.down = Down_Sample_Block()
    
            self.conv1 = Conv_Block(1,64)
            self.conv2 = Conv_Block(64,128)
            self.conv3 = Conv_Block(128,256)
            self.conv4 = Conv_Block(256,512)
            self.conv5 = Conv_Block(512,1024)
            self.conv6 = Conv_Block(1024,512)
            self.conv7 = Conv_Block(512,256)
            self.conv8 = Conv_Block(256,128)
            self.conv9 = Conv_Block(128,64)
            self.up1 = Up_Sample_Block(1024)
            self.up2 = Up_Sample_Block(512)
            self.up3 = Up_Sample_Block(256)
            self.up4 = Up_Sample_Block(128)
            self.out = Output(64,1)
    
        def forward(self,x):
            out1 = self.conv1(x)
            out2 = self.conv2(self.down(out1))
            out3 = self.conv3(self.down(out2))
            out4 = self.conv4(self.down(out3)) # [1, 512, 28, 28]
            out5 = self.conv5(self.down(out4)) # [1, 1024, 14, 14]
            out6 = self.conv6(self.up1(out5,out4)) # [1, 512, 28, 28]
            out7 = self.conv7(self.up2(out6,out3)) # [1, 256, 56, 56]
            out8 = self.conv8(self.up3(out7,out2)) # [1, 128, 112, 112]
            out9 = self.conv9(self.up4(out8,out1)) # [1, 64, 224, 224]
            out = self.out(out9)
            return out
    
    if __name__ == "__main__":
        x = torch.randn((1,1,224,224))
        net = UNet()
        y = net(x)
        print(y.shape)
        
    
    (2)数据集:
    (3)训练结果:
    • train.py:
    import PIL.Image as Image
    import numpy as np 
    import torch 
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset
    import os 
    from model import UNet
    from torchvision.utils import save_image
    
    IMG_PATH = "./data/FundusVessels/JPEGImages/"
    TARGET_PATH = "./data/FundusVessels/Annotations/"
    DST_DIR = "./img"
    
    class EYE_Dataset(Dataset):
        def __init__(self) -> None:
            super().__init__()
            self.img_list =  os.listdir(IMG_PATH)
            self.target_list = os.listdir(TARGET_PATH)
    
        def __len__(self) -> int:
            return len(self.img_list)
        
        def __getitem__(self, index):
            name = self.img_list[index][0:-4]
            img = Image.open(IMG_PATH+f"{name}.jpg").convert("L").resize((224,224))
            lable = Image.open(TARGET_PATH+f"{name}.png").resize((224,224))
            img = np.array(img, dtype=np.float32)/255
            img = img[np.newaxis,:]
            lable = np.array(lable, dtype=np.float32)
            return img, lable
        
    
    
    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        train_dataset = EYE_Dataset()
        train_loader = DataLoader(train_dataset,batch_size=1,shuffle=True)
    
        net = UNet().to(device)
        if os.path.exists("./UNet.pt"):
            params = torch.load("./UNet.pt")
            net.load_state_dict(params)
    
        optim = torch.optim.Adam(net.parameters())
        loss_fn = nn.BCELoss()
        epoch=1
        net.train()
        while True:
            for i,(img,target) in enumerate(train_loader):
                img, target = img.to(device), target.to(device)
                y = net(img)
                loss = loss_fn(y, target.unsqueeze(dim=0))
                optim.zero_grad()
                loss.backward()
                optim.step()
    
                if i%100 == 0:
                    # 保存测试结果
                    img2 =  (y[0]>0.6).float() *255
                    res = torch.stack([img[0],img2],dim=0)
                    save_image(res.cpu(), DST_DIR + f"/epoch{epoch}_{i}.jpg", nrow=2)
                    print(f"epoch {epoch},loss: {loss.item()}")
                    torch.save(net.state_dict(),"./UNet.pt")
                epoch += 1
    

    大概训练了1050轮

    相关文章

      网友评论

          本文标题:一文搞定UNet——图像分割(语义分割)

          本文链接:https://www.haomeiwen.com/subject/cnrpedtx.html