UNet被广泛的应用于图像分割(语义分割的模型),Unet 发表于 2015 年,属于 FCN 的一种变体。可以用于摇杆卫星影像的分割,工业上瑕疵划痕检测等。接下来我们来仔细讨论一下这个网络,并给出基于pytorch的代码。
一、网络结构
UNet闻如其名,整个网络架构就像是一个U字母一样。图像经过下采样,进行特征提取,再经过上采样,输出相应尺寸的图片。如下图所示:
首先我们需要知道怎么看这张图片,蓝色的长方形方块代表数据大小,方块上的数字代表通道数,左边的数字代表输入的宽高。蓝色方块之间有一个蓝色向右的箭头表示卷积层(卷积核3*3,步长为1),两个卷积层组合成了一个卷积核块(Conv_Block),其作用为使得进行特征提取,通道数增加,宽高减4。
向下的红色箭头为下采样块(Down_Sample_Block),其中组成是一个最大池化层(卷积核3*3,步长为2)。输出结果通道数不变,图像宽高减半。
绿色向上的箭头为上采样块(Up_Sample_Block),这里可以使用反卷积操作,或者使用torch.nn.functional.interpolate
函数进行上采样。使得图像通道数减小的同时,宽高*2。
灰色的向右箭头是指将左边的Conv_Block输出的中间变量,与右边上采样块的输出结果进行特征融合,其具体操作为把两个特征的通道堆叠在一起。
(1)卷积层块(Conv_Block)
卷积块中包含了两个卷积层,卷积核大小为3*3,步长设置为1。
"""
卷积块(Conv_Block):
原论文中使用3*3的卷积核,stride为1。但是在实际的代码中,我们加入padding使得特征变量的宽高不会发生改变。
"""
class Conv_Block(nn.Module):
def __init__(self, in_C, out_c) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.Conv2d(in_C, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
nn.Conv2d(out_c, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
)
def forward(self, x):
return self.layers(x)
(2)下采样块(Down_Sample_Block)
下采样块中采用了一层最大池化层,卷积核3*3,步长为2。输出结果通道数不变,图像宽高减半。
"""
下采样块(Down_Sample_Block):
使用最大池化层对图像进行特征提取。
"""
class Down_Sample_Block(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.MaxPool2d(2,2)
)
def forward(self, x):
return self.layers(x)
(3)上采样块(Up_Sample_Block)
上采样块包括了使用卷积层对特征向量通道进行减少,上采样是下采样的反操作,其具体的作用是使得特征变量增加。通常做法有两种:
- 反卷积:使用
torch.nn.ConvTranspose2d(3,3,2,2)
函数可以实现。 - 插值法:使用
torch.nn.functional.interpolate
函数进行上采样。或者也可以使用nn.Upsample(scale_factor=2, mode='bilinear')
完成上采样的操作。
"""
上采样块(Up_Sample_Block):
使用插值法对特征变量进行上采样,使得宽高翻倍,通道减半
"""
class Up_Sample_Block(nn.Module):
def __init__(self, in_c) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.Conv2d(in_c, in_c//2, 3, 1, padding=1, padding_mode="reflect", bias=False),
nn.BatchNorm2d(in_c//2),
nn.ReLU()
)
# 上采样方法1:
self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2)
# 上采样方法2:
self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x,feature):
# 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
x = self.layers(x)
x = self.upsample1(x)
# 下面两行代码是将
# resize = Resize((x.shape[2], x.shape[3]))
# feature = resize(feature)
res = torch.cat((x,feature),dim=1) # 包括了通道融合
return res
(4)通道融合
首先我们可能会问,为什么要做这样的通道融合呢?其实直接对特征变量进行上采样,其实是没有特征信息增加的。在我们上采样的时候,很多的细节特征都被丢失了。所以为了能够补充更多的特征信息,UNet将前面的中间变量拼接到后面上采样的结果中,使得特征更加丰富。其实对于通道融合是有两种做法的。
- 第一种是类似于ResNet,将两个特征变量进行相加。使得原来的变量包含更多的信息。
- 第二种是将特征变量在通道的维度上进行拼接,使用
torch.cat((x1,x2), dim=1)
使得特征信息的增加。UNet选用的是这种方案。
res = torch.cat((x,feature),dim=1)
(5)输出层
输出层是一个卷积层,卷积核大小为1*1,将图像输出为我们需要的样子。
"""
输出模块:
"""
class Output(nn.Module):
def __init__(self,in_c, out_c) -> None:
super().__init__()
self.layers = self.layers = nn.Sequential(
nn.Conv2d(in_c, out_c, 1, 1, bias=False),
nn.BatchNorm2d(out_c),
nn.ReLU(),
nn.Sigmoid()
)
def forward(self, x):
return self.layers(x)
二、代码实现
(1)网络模型
- model.py
import torch
import torch.nn as nn
from torchvision.transforms import Resize
"""
卷积块(Conv_Block):
原论文中使用3*3的卷积核,stride为1。但是在实际的代码中,我们加入padding使得特征变量的宽高不会发生改变。
"""
class Conv_Block(nn.Module):
def __init__(self, in_C, out_c) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.Conv2d(in_C, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
nn.Conv2d(out_c, out_c, 3, 1, 1, padding_mode="reflect",bias=False),nn.BatchNorm2d(out_c),nn.ReLU(),
)
def forward(self, x):
return self.layers(x)
"""
下采样块(Down_Sample_Block):
使用最大池化层对图像进行特征提取。
"""
class Down_Sample_Block(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.MaxPool2d(2,2)
)
def forward(self, x):
return self.layers(x)
"""
上采样块(Up_Sample_Block):
使用插值法对特征变量进行上采样,使得宽高翻倍
"""
class Up_Sample_Block(nn.Module):
def __init__(self, in_c) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.Conv2d(in_c, in_c//2, 3, 1, padding=1, padding_mode="reflect", bias=False),
nn.BatchNorm2d(in_c//2),
nn.ReLU()
)
# 上采样方法1:
self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2)
# 上采样方法2:
self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x,feature):
# 方法3:x = torch.nn.functional.interpolate(input=x,scale_factor=2, mode="nearest")
x = self.layers(x)
x = self.upsample1(x)
# 下面两行代码是将
# resize = Resize((x.shape[2], x.shape[3]))
# feature = resize(feature)
res = torch.cat((x,feature),dim=1)
return res
"""
输出模块:
"""
class Output(nn.Module):
def __init__(self,in_c, out_c) -> None:
super().__init__()
self.layers = self.layers = nn.Sequential(
nn.Conv2d(in_c, out_c, 1, 1, bias=False),
nn.BatchNorm2d(out_c),
nn.ReLU(),
nn.Sigmoid()
)
def forward(self, x):
return self.layers(x)
"""
UNet:
# Unet 总体分为四个模块:
- 一个卷积模块:包含两个卷积层,从输入通道到输出通道。(原论文中有进行尺寸减小2,但是在实际的应用中我们不进行宽高减小)
- 一个下采样模块:maxpooling,宽高减半,加上一个卷积层。后接卷积模块。
- 上采样:使用插值法将宽高增加,在将对应的下采样的特征图进行信息补充。
# 可以通过类似残差模块的相加,也可以使用路由
- 输出:1*1卷积:改变通道数量
"""
class UNet(nn.Module):
def __init__(self) -> None:
super().__init__()
self.down = Down_Sample_Block()
self.conv1 = Conv_Block(1,64)
self.conv2 = Conv_Block(64,128)
self.conv3 = Conv_Block(128,256)
self.conv4 = Conv_Block(256,512)
self.conv5 = Conv_Block(512,1024)
self.conv6 = Conv_Block(1024,512)
self.conv7 = Conv_Block(512,256)
self.conv8 = Conv_Block(256,128)
self.conv9 = Conv_Block(128,64)
self.up1 = Up_Sample_Block(1024)
self.up2 = Up_Sample_Block(512)
self.up3 = Up_Sample_Block(256)
self.up4 = Up_Sample_Block(128)
self.out = Output(64,1)
def forward(self,x):
out1 = self.conv1(x)
out2 = self.conv2(self.down(out1))
out3 = self.conv3(self.down(out2))
out4 = self.conv4(self.down(out3)) # [1, 512, 28, 28]
out5 = self.conv5(self.down(out4)) # [1, 1024, 14, 14]
out6 = self.conv6(self.up1(out5,out4)) # [1, 512, 28, 28]
out7 = self.conv7(self.up2(out6,out3)) # [1, 256, 56, 56]
out8 = self.conv8(self.up3(out7,out2)) # [1, 128, 112, 112]
out9 = self.conv9(self.up4(out8,out1)) # [1, 64, 224, 224]
out = self.out(out9)
return out
if __name__ == "__main__":
x = torch.randn((1,1,224,224))
net = UNet()
y = net(x)
print(y.shape)
(2)数据集:
- VOC数据集:里面包含了图像分割数据,但训练效果并不好。
- 眼球血管分割数据集:可以使用这个开源数据集练练手,这里也给出链接:眼底血管分割 - 飞桨AI Studio (baidu.com)
(3)训练结果:
- train.py:
import PIL.Image as Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import os
from model import UNet
from torchvision.utils import save_image
IMG_PATH = "./data/FundusVessels/JPEGImages/"
TARGET_PATH = "./data/FundusVessels/Annotations/"
DST_DIR = "./img"
class EYE_Dataset(Dataset):
def __init__(self) -> None:
super().__init__()
self.img_list = os.listdir(IMG_PATH)
self.target_list = os.listdir(TARGET_PATH)
def __len__(self) -> int:
return len(self.img_list)
def __getitem__(self, index):
name = self.img_list[index][0:-4]
img = Image.open(IMG_PATH+f"{name}.jpg").convert("L").resize((224,224))
lable = Image.open(TARGET_PATH+f"{name}.png").resize((224,224))
img = np.array(img, dtype=np.float32)/255
img = img[np.newaxis,:]
lable = np.array(lable, dtype=np.float32)
return img, lable
if __name__ == "__main__":
device = "cuda" if torch.cuda.is_available() else "cpu"
train_dataset = EYE_Dataset()
train_loader = DataLoader(train_dataset,batch_size=1,shuffle=True)
net = UNet().to(device)
if os.path.exists("./UNet.pt"):
params = torch.load("./UNet.pt")
net.load_state_dict(params)
optim = torch.optim.Adam(net.parameters())
loss_fn = nn.BCELoss()
epoch=1
net.train()
while True:
for i,(img,target) in enumerate(train_loader):
img, target = img.to(device), target.to(device)
y = net(img)
loss = loss_fn(y, target.unsqueeze(dim=0))
optim.zero_grad()
loss.backward()
optim.step()
if i%100 == 0:
# 保存测试结果
img2 = (y[0]>0.6).float() *255
res = torch.stack([img[0],img2],dim=0)
save_image(res.cpu(), DST_DIR + f"/epoch{epoch}_{i}.jpg", nrow=2)
print(f"epoch {epoch},loss: {loss.item()}")
torch.save(net.state_dict(),"./UNet.pt")
epoch += 1
大概训练了1050轮
网友评论