Unet图像分割

作者: AsdilFibrizo | 来源:发表于2019-12-16 11:55 被阅读0次
Unet网络是一种图像语义分割网络,图像语义分割网络让计算机根据图像的语义来进行分割,例如让计算机在输入下面下图,能够输出指定分割的图片。
基本图片分割

原图中,物体被分为三类,1.背景, 2.人, 3.自行车

地理信息

语义分割的用处很多,比如说上图中分割卫星图,通过多伦迭代,Prediction逐渐与Grond Truth一致。

Unet网络结构
Unet网络结构如下,整个网络形如字母U。简单的来说,整个网络分为两个部分,左边部分负责特征提取,随着网络层加深,网络的channel逐渐变大,"图片"逐渐变小。右边的网络负责特征的还原,整个网络实际上就是一个编码-解码器。需要注意的是,整个网络最出彩的地方是灰色箭头的部分。在编码的过程中,部分信息丢失了(Maxpooling和Conv2D)。在解码时,加入与之对应的编码层信息。从图上来看的话就是右边每一层网络都加入了一部分"白"色的"图片"(特征)。 全连接语义分割

那么这里就有个问题,为什么要这么复杂的做一个编码-解码器?上图的一个简单的多层卷积就可以完成图像语义分割。


编码解码器

原因就在于随着卷积核的越大,伴随着参数就会成倍增长,一是运算效率会大大下降,其次不利于收敛。这里强烈推荐看一篇文章“看懂”卷积神经网(Visualizing and Understanding Convolutional Networks)

工作原理1

这里讲一下,Unet工作原理,假设我们有一张图片,如左图所示,我们会根据实际需要将需要识别的区域转化为特定的"编码"作为类标签。


工作原理2
工作原理3

实际上每个需要识别的物体需要一个channel,有多少个需要识别的物体,就有多少个输出channel,最后再做一个叠加就是最终我们想分割的结果。

下面哪一个简单的实例代码来说明Unet的工作原理,源代码Github在这里,下面我做一些解释性说明

1.首先引入必要包
%matplotlib inline
%load_ext autoreload
%autoreload 2
import os, sys
import random
import copy
import itertools
import time
from functools import reduce
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
from torchsummary import summary
2.生成模拟数据,这一部分不用太纠结代码,复制粘贴就可以
def generate_random_data(height, width, count):
    x, y = zip(*[generate_img_and_mask(height, width) for i in range(0, count)])
    X = np.asarray(x) * 255
    X = X.repeat(3, axis=1).transpose([0, 2, 3, 1]).astype(np.uint8)
    Y = np.asarray(y)
    return X, Y

def generate_img_and_mask(height, width):
    shape = (height, width)
    triangle_location = get_random_location(*shape)
    circle_location1 = get_random_location(*shape, zoom=0.7)
    circle_location2 = get_random_location(*shape, zoom=0.5)
    mesh_location = get_random_location(*shape)
    square_location = get_random_location(*shape, zoom=0.8)
    plus_location = get_random_location(*shape, zoom=1.2)

    # Create input image
    arr = np.zeros(shape, dtype=bool)
    arr = add_triangle(arr, *triangle_location)
    arr = add_circle(arr, *circle_location1)
    arr = add_circle(arr, *circle_location2, fill=True)
    arr = add_mesh_square(arr, *mesh_location)
    arr = add_filled_square(arr, *square_location)
    arr = add_plus(arr, *plus_location)
    arr = np.reshape(arr, (1, height, width)).astype(np.float32)

    # Create target masks
    masks = np.asarray([
        add_filled_square(np.zeros(shape, dtype=bool), *square_location),
        add_circle(np.zeros(shape, dtype=bool), *circle_location2, fill=True),
        add_triangle(np.zeros(shape, dtype=bool), *triangle_location),
        add_circle(np.zeros(shape, dtype=bool), *circle_location1),
         add_filled_square(np.zeros(shape, dtype=bool), *mesh_location),
        # add_mesh_square(np.zeros(shape, dtype=bool), *mesh_location),
        add_plus(np.zeros(shape, dtype=bool), *plus_location)
    ]).astype(np.float32)
    return arr, masks

def add_square(arr, x, y, size):
    s = int(size / 2)
    arr[x-s,y-s:y+s] = True
    arr[x+s,y-s:y+s] = True
    arr[x-s:x+s,y-s] = True
    arr[x-s:x+s,y+s] = True
    return arr

def add_filled_square(arr, x, y, size):
    s = int(size / 2)
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, yy > y - s, yy < y + s]))

def logical_and(arrays):
    new_array = np.ones(arrays[0].shape, dtype=bool)
    for a in arrays:
        new_array = np.logical_and(new_array, a)
    return new_array

def add_mesh_square(arr, x, y, size):
    s = int(size / 2)
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, xx % 2 == 1, yy > y - s, yy < y + s, yy % 2 == 1]))

def add_triangle(arr, x, y, size):
    s = int(size / 2)
    triangle = np.tril(np.ones((size, size), dtype=bool))
    arr[x-s:x-s+triangle.shape[0],y-s:y-s+triangle.shape[1]] = triangle
    return arr

def add_circle(arr, x, y, size, fill=False):
    xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
    circle = np.sqrt((xx - x) ** 2 + (yy - y) ** 2)
    new_arr = np.logical_or(arr, np.logical_and(circle < size, circle >= size * 0.7 if not fill else True))
    return new_arr

def add_plus(arr, x, y, size):
    s = int(size / 2)
    arr[x-1:x+1,y-s:y+s] = True
    arr[x-s:x+s,y-1:y+1] = True
    return arr

def get_random_location(width, height, zoom=1.0):
    x = int(width * random.uniform(0.1, 0.9))
    y = int(height * random.uniform(0.1, 0.9))
    size = int(min(width, height) * random.uniform(0.06, 0.12) * zoom)
    return (x, y, size)

def plot_img_array(img_array, ncol=3):
    nrow = len(img_array) // ncol
    f, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all', figsize=(ncol * 4, nrow * 4))
    for i in range(len(img_array)):
        plots[i // ncol, i % ncol]
        plots[i // ncol, i % ncol].imshow(img_array[i])

def plot_side_by_side(img_arrays):
    flatten_list = reduce(lambda x,y: x+y, zip(*img_arrays))
    plot_img_array(np.array(flatten_list), ncol=len(img_arrays))

def plot_errors(results_dict, title):
    markers = itertools.cycle(('+', 'x', 'o'))
    plt.title('{}'.format(title))
    for label, result in sorted(results_dict.items()):
        plt.plot(result, marker=next(markers), label=label)
        plt.ylabel('dice_coef')
        plt.xlabel('epoch')
        plt.legend(loc=3, bbox_to_anchor=(1, 0))
    plt.show()

def masks_to_colorimg(masks):
    colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56)])
    colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
    channels, height, width = masks.shape
    for y in range(height):
        for x in range(width):
            selected_colors = colors[masks[:,y,x] > 0.5]
            if len(selected_colors) > 0:
                colorimg[y,x,:] = np.mean(selected_colors, axis=0)
    return colorimg.astype(np.uint8)
3.看一下输入数据和类标签数据
# 生成图片与类标签(192*192, 3张)
input_images, target_masks = generate_random_data(192, 192, count=1)
print(f'输入数据维度:{input_images.shape}')
print(f'输出数据维度:{target_masks.shape}')
# 修改数据类型,方便画图
input_images_rgb = [x.astype(np.uint8) for x in input_images]
# 将灰度图片(channel=1)变为RGB图片(channel=3)
target_masks_rgb = [masks_to_colorimg(x) for x in target_masks]
# 显示模拟图片
plot_side_by_side([input_images_rgb, target_masks_rgb])

['out']:输入数据维度:(1, 192, 192, 3)
['out']:输出数据维度:(1, 6, 192, 192)

训练数据一个(192,192,3(RGB通道))的RGB图片, 类标签数据是一组灰度图片(6,192,192),每个需要识别的图形是一个灰度图片一共6个图形。
模拟数据

左图为输入数据,右图中将类标签灰度图片加了RBG通道,然后6张图叠加的效果图(我们只需预测6张灰度图即可)。

4.数据生成器
# 一个简单的pytorch数据迭代器
class SimDataset(Dataset):
    def __init__(self, count, transform=None):
        # count:每次需要生成的数据量
        # transform指定数据转化器
        self.input_images, self.target_masks = generate_random_data(192, 192, count=count)        
        self.transform = transform

    def __len__(self):
        return len(self.input_images)
    
    def __getitem__(self, idx):
        image = self.input_images[idx]
        mask = self.target_masks[idx]
        if self.transform:
            image = self.transform(image)
        return [image, mask]
# use same transform for train/val for this example
trans = transforms.Compose([
    transforms.ToTensor(),
])
# 这里生成2000组模拟数据作为训练集, 200组模拟数据作为测试集
train_set = SimDataset(2000, transform = trans)
val_set = SimDataset(200, transform = trans)
batch_size = 25
dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

Unet网络

Unet编码层
# Unet编码层, 如上图所示,包含两个(卷积+Relu)
# 原始Unet网络中padding=0(填充),所以"图片"会变小
# 572*572--->570*570--->568*568
def double_conv(in_channels, out_channels):
  return nn.Sequential(
      nn.Conv2d(in_channels, out_channels, 3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channels, out_channels, 3, padding=1),
      nn.ReLU(inplace=True)
  )
Unet编码层2
Unet解码层1
5.定义网络
# Unet经过一次double_conv通道数加倍(变厚),然后使用Maxpool, "图片"维度/2(变小)
class Unet(nn.Module):
  def __init__(self, n_class):
    super().__init__()
    self.dconv_down1 = double_conv(3, 64)
    self.dconv_down2 = double_conv(64, 128)
    self.dconv_down3 = double_conv(128, 256)
    self.dconv_down4 = double_conv(256, 512)
    self.maxpool = nn.MaxPool2d(2)
    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 这里使用双线性插值
    self.dconv_up3 = double_conv(256 + 512, 256)
    self.dconv_up2 = double_conv(128 + 256, 128)
    self.dconv_up1 = double_conv(128 + 64, 64)
    self.conv_last = nn.Conv2d(64, n_class, 1) # 最后一层, 需要识别多少种目标,则输出多少个channel(n_class)

  def forward(self, x):
    conv1 = self.dconv_down1(x)
    x = self.maxpool(conv1) # 对应上图Unet编码层2
    conv2 = self.dconv_down2(x)
    x = self.maxpool(conv2)
    conv3 = self.dconv_down3(x)
    x = self.maxpool(conv3)
    x = self.dconv_down4(x) #到底了
    x = self.upsample(x) # 双线性插值,还原"图片"
    # 解码数据与对应编码数据concat使channel数增加, 弥补了单纯上采样导致的信息还原不足
    # 这一步很关键(也就是图Unet解码层1中数据变"厚")
    x = torch.cat([x, conv3], dim=1) 
    x = self.dconv_up3(x)
    x = self.upsample(x)        
    x = torch.cat([x, conv2], dim=1) # 256+128
    x = self.dconv_up2(x)# 
    x = self.upsample(x)        
    x = torch.cat([x, conv1], dim=1)
    x = self.dconv_up1(x)
    out = self.conv_last(x)
    return out
# 这里打印一下网络结构
model = Unet(6)
summary(model, input_size=(3, 224, 224))
数值化网络结构
6.损失函数
def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()    
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()
# 这里使用两种损失函数加权
def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target) 
    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)
    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)
    return loss

def print_metrics(metrics, epoch_samples, phase):    
    outputs = []
    for k in metrics.keys():
        outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))
    print("{}: {}".format(phase, ", ".join(outputs))) 

def train_model(model, optimizer, scheduler, num_epochs=25):
  best_model_wts = copy.deepcopy(model.state_dict())
  best_loss = 1e10
  for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-'*10)
    since = time.time()
    for phase in ['train', 'val']:
      if phase == 'train':
        scheduler.step()
        model.train()  # Set model to training mode
      else:
        model.eval()   # Set model to evaluate mode
      metrics = defaultdict(float)
      epoch_samples = 0

      for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        with torch.set_grad_enabled(phase == 'train'):
          outputs = model(inputs)
          loss = calc_loss(outputs, labels, metrics)
          if phase == 'train':
            loss.backward()
            optimizer.step()
        epoch_samples += inputs.size(0)
      print_metrics(metrics, epoch_samples, phase)
      epoch_loss = metrics['loss'] / epoch_samples

      if phase == 'val' and epoch_loss < best_loss:
        print("saving best model")
        best_loss = epoch_loss
        best_model_wts = copy.deepcopy(model.state_dict())
    time_elapsed = time.time() - since
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  print('Best val loss: {:4f}'.format(best_loss))
  # load best model weights
  model.load_state_dict(best_model_wts)
  return model
7.训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_class = 6
model = Unet(num_class).to(device)
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=40)
训练结果

相关文章

  • Unet图像分割

    Unet网络是一种图像语义分割网络,图像语义分割网络让计算机根据图像的语义来进行分割,例如让计算机在输入下面下图,...

  • 图像语义分割基础知识整理(CNN,FCN,Unet,CVPR-D

    整理的一些关于图像语义分割,CNN,FCN,Unet等的基础知识同时对CVPR-DeepGlobe路网分割竞赛部分...

  • U-Net

    主要参考资料:Unet的网站和论文。 U-Net最早用作生物图像的分割,后来在目标检测、图像转换,以及Tone M...

  • 用Unet实现图像分割(by pytorch)

    Segmentation Figure1来自CamVid database,专为目标识别(Object Decti...

  • Pytorch-UNet介绍

    简介 UNet网络主要用在医学图像分割任务上,网络的结构特点就是: 全卷积网络,没有全连接层,训练参数少,模型体积...

  • 论文泛读:《Esophageal Gross Tumor Vol

    简 介: 将 dense block 安插到 3D-Unet, 在食管道 CT肿瘤分割任务中好于 3D-Unet,...

  • Unet 多分类分割,附开源代码

    Unet图像分割在大多的开源项目中都是针对于二分类,理论来说,对于多分类问题,依旧可行。可小编尝试过很多的方法在原...

  • 图像分割

    图像分割 什么是图像分割? 图像分割就是预测图像中每一个像素所属的类别或者物体。图像分割有两个子问题,一个是只预测...

  • Unet

    网络结构 本文提出了一个分割网络——Unet,Unet借鉴了FCN网络,其网络结构包括两个对称部分:前面一部分网络...

  • 图像分割算法总结

    图像处理的很多任务都离不开图像分割。因为图像分割在cv中实在太重要(有用)了,就先把图像分割的常用算法做个总...

网友评论

    本文标题:Unet图像分割

    本文链接:https://www.haomeiwen.com/subject/vfprnctx.html