Resnet在图像分割领域的应用- Linknet
image.png
image.png
[1707.03718] LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation (arxiv.org)
1707.03718.pdf (arxiv.org)
image.png
输出的full-conv 实际上是反卷积
输入7x7的卷积核,输入特征3,64个卷积核,/2 表示图像缩放为原来的1/2
max-pool:3x3池化窗口,/2表示把图像缩放为原来的1/2
输出full-conv 反卷积,3x3卷积核,输入64个特征,输出32个特征,2表示图像放大为之前的2倍
中间的conv 用3x3卷积核,输入32 输出32,没有对图片进行缩放,只是一种特征的提取
最后输出的full-conv反卷积:2x2卷积核,输入32个特征,N个卷积核(N取决于要分割的类别数),2表示把图像放大为原来的2倍
image.png
编码器模块(encoder block):输入依次经过两个残差模块。
conv[(3x3),(m,n),/2] 3x3卷积核,输入m个特征,n个卷积核,图像缩放为原来的一半 image.png
image.png
image.png
image.png
image.png
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
from torchvision import transforms
import os
import glob
from PIL import Image
# Number of samples per mini-batch, used for both the training and the test DataLoader.
BATCH_SIZE = 128
# 绘制原图
# pil_img = Image.open('./data/hk/training/00001.png')
# np_img = np.array(pil_img)
# plt.imshow(np_img)
# plt.show()
# 绘制分割后的图
# pil_img = Image.open('./data/hk/training/00001_matte.png')
# np_img = np.array(pil_img)
# plt.imshow(np_img)
# plt.show()
# np_img.max(), np_img.min() # (255, 0)
# np_img.shape # (800, 600)
# np.unique(np_img) # array([  0, ..., 255]) 像素值在0-255之间,不是二分类的0或1
# 绘制像素点为0/1的图片
# pil_img = Image.open('./data/hk/training/00001_matte.png')
# np_img = np.array(pil_img)
# np_img[np_img>0]=1
# plt.imshow(np_img)
# plt.show()
# np.unique(np_img) # array([0, 1], dtype=uint8) 此时,像素只包括0和1. 这种变换对原有像素有一定的损失。
def _split_image_and_matte_paths(paths):
    """Split PNG paths into parallel lists of images and their mattes.

    A path is treated as a matte when its file stem ends with ``_matte``
    (``00001_matte.png`` is the matte of ``00001.png``).

    Args:
        paths: iterable of file paths, e.g. the result of ``glob.glob``.

    Returns:
        ``(image_paths, matte_paths)``: lists of equal length, images sorted,
        where ``matte_paths[i]`` is the matte of ``image_paths[i]``.

    Raises:
        ValueError: if some image has no matte or some matte has no image.
    """
    # glob.glob returns files in filesystem order, which is arbitrary on many
    # platforms; sort, then pair each image with its matte by name instead of
    # relying on two independently filtered lists lining up by position.
    ordered = sorted(paths)

    def _is_matte(path):
        stem = os.path.splitext(os.path.basename(path))[0]
        return stem.endswith('_matte')

    def _matte_of(path):
        root, ext = os.path.splitext(path)
        return root + '_matte' + ext

    image_paths = [p for p in ordered if not _is_matte(p)]
    mattes = {p for p in ordered if _is_matte(p)}
    paired = [_matte_of(p) for p in image_paths]
    if set(paired) != mattes:
        raise ValueError(
            'image and matte files do not pair up one-to-one: '
            f'{len(image_paths)} images, {len(mattes)} mattes'
        )
    return image_paths, paired


# Training split: image paths and their matte (annotation) paths, index-aligned.
all_pics = glob.glob('./data/hk/training/*.png')
images, annotations = _split_image_and_matte_paths(all_pics)

# Shuffle once with a fixed seed, applying the same permutation to both
# lists so every image keeps its own matte.
np.random.seed(2021)
index = np.random.permutation(len(images))
images = np.array(images)[index]
anno = np.array(annotations)[index]

# Test split (not shuffled).
all_test_pics = glob.glob('./data/hk/testing/*.png')
test_images, test_anno = _split_image_and_matte_paths(all_test_pics)
# Preprocessing shared by the portrait images and their matte annotations:
# resize to a fixed 256x256 resolution, then convert the PIL image to a tensor.
# NOTE(review): Resize interpolates bilinearly by default, and the same
# pipeline is applied to the mattes in Portrait_dataset; nearest-neighbour
# resizing is usually preferred for label maps -- confirm this is intended.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)
class Portrait_dataset(data.Dataset):
    """Pairs portrait images with their mattes as ``(image, mask)`` tensors.

    ``__getitem__`` returns ``(img_tensor, anno_tensor)``:

    * ``img_tensor``: the image passed through the module-level ``transform``.
    * ``anno_tensor``: the matte passed through the same ``transform``,
      squeezed, and binarised to a ``torch.long`` tensor. Every pixel with a
      value greater than 0 becomes 1; all other pixels are 0.

    Args:
        img_paths: sequence of image file paths.
        anno_paths: sequence of matte file paths; ``anno_paths[i]`` must be
            the matte of ``img_paths[i]``.
    """

    def __init__(self, img_paths, anno_paths):
        self.imgs = img_paths
        self.annos = anno_paths

    def __getitem__(self, index):
        img_path = self.imgs[index]
        anno_path = self.annos[index]
        # `with` closes the file handles once the transform has read the pixels.
        with Image.open(img_path) as pil_img:
            img_tensor = transform(pil_img)
        with Image.open(anno_path) as pil_anno:
            anno_tensor = transform(pil_anno)
        # ToTensor rescales 8-bit pixels into [0, 1], so binarise *before*
        # converting to an integer dtype. Casting first would truncate every
        # soft matte value below 1.0 to 0 and keep only fully opaque pixels.
        # squeeze drops the size-1 channel dimension (assumes single-channel
        # mattes, giving an (H, W) mask -- TODO confirm for all matte files).
        anno_tensor = (torch.squeeze(anno_tensor) > 0).long()
        return img_tensor, anno_tensor

    def __len__(self):
        return len(self.imgs)
# Wrap the (image, mask) path lists in datasets and batch them.
train_dataset = Portrait_dataset(images, anno)
test_dataset = Portrait_dataset(test_images, test_anno)

# Only the training loader reshuffles each epoch.
train_dl = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dl = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)

# Fetch one training batch for the (commented-out) visual sanity check below.
imgs_batch, annos_batch = next(iter(train_dl))
# imgs_batch.shape # (BATCH_SIZE, 3, 256, 256):batch=BATCH_SIZE, channel=3, 大小256*256
# annos_batch.shape # (BATCH_SIZE, 256, 256):Dataset中已用torch.squeeze去掉了大小为1的channel维度
# img = imgs_batch[0].permute(1,2,0).numpy() # 对第1张图片进行绘图。 permute将channel放到最后面
# anno = annos_batch[0].numpy() # anno图片没有channle这个属性,因此不需要用permute
# plt.subplot(1,2,1) # 绘制1行2列的第1张图
# plt.imshow(img)
# plt.subplot(1,2,2) # 绘制1行2列的第2张图
# plt.imshow(anno)
# 创建LinkNet模型
# 1. 编写卷积模块(卷积 + BN + activate)
# 2. 编写反卷积模块(反卷积+BN+activate)
# 3. 编码器(4个卷积模块)
# 4. 解码器(卷积模块+反卷积模块+卷积模块)
# 5. 实现整体网络结构
# 1. 卷积模块
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps (convolution filters).
        k_size: square kernel size.
        stride: convolution stride; the spatial down-scaling is controlled
            here (e.g. ``stride=2`` with ``k_size=3, pad=1`` halves H and W).
        pad: zero padding added on each side.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, pad=1):
        super().__init__()
        # The attribute name and layer order determine the state_dict keys
        # ("conv_relu.0.*", "conv_relu.1.*"), so keep them stable for saved
        # checkpoints.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=k_size, stride=stride, padding=pad),
            nn.BatchNorm2d(out_channels),
            # In-place ReLU overwrites the BatchNorm output instead of allocating a new tensor.
            nn.ReLU(inplace=True),
        ]
        self.conv_relu = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_relu(x)
# 2. 反卷积模块
class DeconvBlock(nn.Module):
    """ConvTranspose2d, optionally followed by BatchNorm2d + ReLU.

    Beware the argument naming: ``padding`` is forwarded as the transposed
    convolution's ``padding``, while ``pad`` is forwarded as its
    ``output_padding``. With the defaults (``k_size=3, stride=2, padding=1,
    pad=1``) the output height/width are exactly twice the input's.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=2, pad=1, padding=1):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=k_size,
            stride=stride,
            padding=padding,
            output_padding=pad,
        )
        # Always built so the state_dict layout is fixed; only applied when
        # forward() is called with is_act=True.
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_act=True):
        """Upsample ``x``; with ``is_act=False`` return the raw transposed-conv output."""
        upsampled = self.deconv(x)
        if not is_act:
            return upsampled
        return torch.relu(self.bn(upsampled))
# 3.编码器
class EncodeBlock(nn.Module):
    """LinkNet encoder block made of two residual units.

    Unit 1 maps ``in_channels`` to ``out_channels`` and halves the spatial
    size. Its main branch is ``conv1_1`` (stride 2) -> ``conv1_2``, added to a
    stride-2 ``shortcut`` projection of the input so both branches have the
    same shape.

    Unit 2 keeps the shape. Its main branch is ``conv2_1`` -> ``conv2_2``,
    added to an identity skip from the output of unit 1.

    Output: ``out_channels`` feature maps at half the input height/width
    (for even input sizes).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Residual unit 1 (downsampling): stride 2 halves H and W.
        self.conv1_1 = ConvBlock(in_channels, out_channels, stride=2)
        self.conv1_2 = ConvBlock(out_channels, out_channels)
        # Residual unit 2 (shape-preserving).
        self.conv2_1 = ConvBlock(out_channels, out_channels)
        self.conv2_2 = ConvBlock(out_channels, out_channels)
        # Projection shortcut for unit 1: matches channels and halves H/W.
        self.shortcut = ConvBlock(in_channels, out_channels, stride=2)

    def forward(self, x):
        # Unit 1: conv1_1 -> conv1_2 plus the projected input. conv1_2 belongs
        # here; conv2_1 is unit 2's layer and must run exactly once.
        residue = self.shortcut(x)
        out1 = self.conv1_2(self.conv1_1(x)) + residue
        # Unit 2: conv2_1 -> conv2_2 plus the identity skip from unit 1.
        out2 = self.conv2_2(self.conv2_1(out1))
        return out2 + out1
# 4. 解码器
class DecodeBlock(nn.Module):
    """LinkNet decoder block: bottleneck, x2 upsample, expand.

    1x1 conv (``in_channels`` -> ``in_channels // 4``), then a stride-2
    transposed conv that doubles H and W, then 1x1 conv
    (``in_channels // 4`` -> ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid_channels = in_channels // 4
        self.conv1 = ConvBlock(in_channels, mid_channels, k_size=1, pad=0)
        self.deconv = DeconvBlock(mid_channels, mid_channels)
        self.conv2 = ConvBlock(mid_channels, out_channels, k_size=1, pad=0)

    def forward(self, x):
        return self.conv2(self.deconv(self.conv1(x)))
class Net(nn.Module):
    """LinkNet-style semantic segmentation network.

    Args:
        num_classes: number of output channels (segmentation classes) of the
            final transposed convolution. Defaults to 2 (background and
            foreground), the value previously hard-coded.

    ``forward`` takes an (N, 3, H, W) batch and returns raw per-class logits
    of shape (N, num_classes, H, W). The skip additions only line up when H
    and W halve cleanly six times (e.g. multiples of 64, such as 256x256).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Stem: 7x7 conv, stride 2, 3 -> 64 channels.
        self.init_conv = ConvBlock(3, 64, k_size=7, stride=2, pad=3)
        # NOTE: the LinkNet paper uses a 3x3 max-pool here; this uses a 2x2
        # window (stride defaults to the kernel size), which also halves H/W.
        self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        # Encoders: each halves H/W.
        self.encode1 = EncodeBlock(64, 64)
        self.encode2 = EncodeBlock(64, 128)
        self.encode3 = EncodeBlock(128, 256)
        self.encode4 = EncodeBlock(256, 512)
        # Decoders: each doubles H/W.
        self.decode4 = DecodeBlock(512, 256)
        self.decode3 = DecodeBlock(256, 128)
        self.decode2 = DecodeBlock(128, 64)
        self.decode1 = DecodeBlock(64, 64)
        # Classifier head: x2 upsample, 3x3 conv, x2 upsample to num_classes.
        self.deconv_last1 = DeconvBlock(64, 32)
        self.conv_last = ConvBlock(32, 32)
        self.deconv_last2 = DeconvBlock(32, num_classes, k_size=2, pad=0, padding=0)

    def forward(self, x):
        x = self.init_conv(x)     # (N, 64, H/2, W/2)
        x = self.init_maxpool(x)  # (N, 64, H/4, W/4)
        e1 = self.encode1(x)      # (N, 64, H/8, W/8)
        e2 = self.encode2(e1)     # (N, 128, H/16, W/16)
        e3 = self.encode3(e2)     # (N, 256, H/32, W/32)
        e4 = self.encode4(e3)     # (N, 512, H/64, W/64)
        # Decoder outputs are summed with the encoder features of equal shape.
        d4 = self.decode4(e4) + e3  # (N, 256, H/32, W/32)
        d3 = self.decode3(d4) + e2  # (N, 128, H/16, W/16)
        d2 = self.decode2(d3) + e1  # (N, 64, H/8, W/8)
        d1 = self.decode1(d2)       # (N, 64, H/4, W/4)
        f1 = self.deconv_last1(d1)  # (N, 32, H/2, W/2)
        f2 = self.conv_last(f1)     # (N, 32, H/2, W/2)
        # No BN/ReLU on the last layer: return raw logits for CrossEntropyLoss.
        f3 = self.deconv_last2(f2, is_act=False)  # (N, num_classes, H, W)
        return f3
from torch.optim import lr_scheduler

model = Net()
if torch.cuda.is_available():
    # nn.Module.to moves the parameters in place and returns the same module.
    model = model.to('cuda')
# Pixel-wise classification loss on raw logits.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
# StepLR multiplies the learning rate by gamma=0.1 every step_size=7 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
def fit(epoch, model, trainloader, testloader):
    """Run one training epoch of ``model``, then one evaluation pass.

    Uses the module-level ``loss_fn``, ``optimizer`` and ``exp_lr_scheduler``;
    the scheduler is stepped once per call. Prints a one-line summary.

    Args:
        epoch: epoch index, only used in the printed summary.
        model: network whose output ``loss_fn`` accepts together with the
            target ``y``; predicted classes are ``argmax`` over dim 1.
        trainloader: iterable of ``(x, y)`` training batches.
        testloader: iterable of ``(x, y)`` evaluation batches.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)``.
        Each loss is the mean of the per-batch losses. Each accuracy is the
        fraction of correctly classified elements of ``y`` (pixels, for
        masks). A value is ``nan`` if its loader yields no batches.
    """
    # Move batches to wherever the model's parameters live, so this works
    # whether or not the model was moved to CUDA.
    device = next(model.parameters()).device

    def _ratio(numerator, denominator):
        # nan instead of ZeroDivisionError when a loader is empty.
        return numerator / denominator if denominator else float('nan')

    correct = 0
    total = 0
    running_loss = 0.0
    n_batches = 0
    model.train()
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            # Count every labelled element instead of assuming 256x256 masks.
            total += y.numel()
            running_loss += loss.item()
            n_batches += 1
    exp_lr_scheduler.step()
    # loss_fn already averages within a batch, so average over batches
    # (dividing by the number of samples would shrink it by ~batch size).
    epoch_loss = _ratio(running_loss, n_batches)
    epoch_acc = _ratio(correct, total)

    test_correct = 0
    test_total = 0
    test_running_loss = 0.0
    test_batches = 0
    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.numel()
            test_running_loss += loss.item()
            test_batches += 1
    epoch_test_loss = _ratio(test_running_loss, test_batches)
    epoch_test_acc = _ratio(test_correct, test_total)

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
          )
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
epochs = 40
# Per-epoch history: one value appended per call to fit().
train_loss, train_acc, test_loss, test_acc = [], [], [], []
for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(
        epoch, model, train_dl, test_dl
    )
    for history, value in (
        (train_loss, epoch_loss),
        (train_acc, epoch_acc),
        (test_loss, epoch_test_loss),
        (test_acc, epoch_test_acc),
    ):
        history.append(value)
# 保存模型
# PATH = 'unet_model.pth'
# torch.save(model.state_dict(), PATH)
# 测试模型
# my_model = Net()
# my_model.load_state_dict(torch.load(PATH))
# num=3 # 取出3张图片
# image, mask = next(iter(test_dl))
# pred_mask = my_model(image)
# plt.figure(figsize=(10, 10))
# for i in range(num):
# plt.subplot(num, 3, i*num+1) # i从0开始, 第一行 第1张图片的原图
# plt.imshow(image[i].permute(1,2,0).cpu().numpy())
# plt.subplot(num, 3, i*num+2) # 实际的分割图
# plt.imshow(mask[i].cpu().numpy())
# plt.subplot(num, 3, i*num+3) # 预测出的分割图
# plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy()) # detach 取出实际结果
# train数据集上测试
# image, mask = next(iter(train_dl))
# pred_mask = my_model(image)
# plt.figure(figsize=(10, 10))
# for i in range(num):
# plt.subplot(num, 3, i*num+1)
# plt.imshow(image[i].permute(1,2,0).cpu().numpy())
# plt.subplot(num, 3, i*num+2)
# plt.imshow(mask[i].cpu().numpy())
# plt.subplot(num, 3, i*num+3)
# plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy())
# 加入IOU(交并比)评价指标
def fit(epoch, model, trainloader, testloader):
    """Run one training epoch of ``model``, then one evaluation pass, tracking IoU.

    Uses the module-level ``loss_fn``, ``optimizer`` and ``exp_lr_scheduler``;
    the scheduler is stepped once per call. Prints training and test
    summaries, including the mean per-batch foreground IoU.

    Args:
        epoch: epoch index, only used in the printed summary.
        model: network whose output ``loss_fn`` accepts together with the
            target ``y``; predicted classes are ``argmax`` over dim 1.
        trainloader: iterable of ``(x, y)`` training batches.
        testloader: iterable of ``(x, y)`` evaluation batches.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)``.
        Each loss is the mean of the per-batch losses. Each accuracy is the
        fraction of correctly classified elements of ``y``. A value is
        ``nan`` if its loader yields no batches. IoU is printed, not returned.
    """
    # Move batches to wherever the model's parameters live; without this the
    # loop fails with a device mismatch once the model is on CUDA.
    device = next(model.parameters()).device

    def _ratio(numerator, denominator):
        # nan instead of ZeroDivisionError when a loader is empty.
        return numerator / denominator if denominator else float('nan')

    def _batch_iou(target, pred):
        # IoU of the non-zero (foreground) pixels over the whole batch, as a
        # Python float (np-friendly even for CUDA tensors). If neither mask
        # has any foreground they agree perfectly, so return 1.0 rather than
        # evaluating 0 / 0.
        union = torch.logical_or(target, pred).sum().item()
        if union == 0:
            return 1.0
        intersection = torch.logical_and(target, pred).sum().item()
        return intersection / union

    correct = 0
    total = 0
    running_loss = 0.0
    n_batches = 0
    epoch_iou = []
    model.train()
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            # Count every labelled element instead of assuming 256x256 masks.
            total += y.numel()
            running_loss += loss.item()
            n_batches += 1
            epoch_iou.append(_batch_iou(y, y_pred))
    exp_lr_scheduler.step()
    # loss_fn already averages within a batch, so average over batches.
    epoch_loss = _ratio(running_loss, n_batches)
    epoch_acc = _ratio(correct, total)

    test_correct = 0
    test_total = 0
    test_running_loss = 0.0
    test_batches = 0
    epoch_test_iou = []
    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.numel()
            test_running_loss += loss.item()
            test_batches += 1
            epoch_test_iou.append(_batch_iou(y, y_pred))
    epoch_test_loss = _ratio(test_running_loss, test_batches)
    epoch_test_acc = _ratio(test_correct, test_total)

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'IOU:', round(_ratio(sum(epoch_iou), len(epoch_iou)), 3))
    print()
    print('     ', 'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3),
          'test_iou:', round(_ratio(sum(epoch_test_iou), len(epoch_test_iou)), 3)
          )
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
# NOTE(review): `model`, `optimizer` and `exp_lr_scheduler` are the objects
# created once earlier in the notebook. If the previous 40-epoch loop was also
# run, this continues training that same model with a learning rate that
# StepLR has already decayed -- confirm that is intended.
epochs = 40
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    results = fit(epoch, model, train_dl, test_dl)
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = results
    train_loss += [epoch_loss]
    train_acc += [epoch_acc]
    test_loss += [epoch_test_loss]
    test_acc += [epoch_test_acc]
# Save the model: persist only the learned weights (the state_dict), not the module object.
PATH = 'linknet_model.pth'
torch.save(obj=model.state_dict(), f=PATH)
# Test the model: reload the saved weights into a fresh network. Without
# load_state_dict the plots would show an untrained, randomly initialised net.
my_model = Net()
my_model.load_state_dict(torch.load(PATH, map_location='cpu'))
# eval(): BatchNorm uses its running statistics instead of per-batch statistics.
my_model.eval()
num = 3  # number of samples (plot rows) to show


def _show_predictions(net, loader, n_rows):
    """Plot ``n_rows`` samples from the first batch of ``loader``.

    Each row shows the input image, the ground-truth mask and the predicted
    mask (argmax over the class channel).

    Returns:
        ``(image, mask, pred_mask)`` where ``image``/``mask`` are the full
        first batch and ``pred_mask`` holds logits for the first ``n_rows``
        samples only (no need to run the whole batch through the net).
    """
    image, mask = next(iter(loader))
    with torch.no_grad():
        pred_mask = net(image[:n_rows])
    plt.figure(figsize=(10, 10))
    for i in range(n_rows):
        # Three columns per row, so the row offset is i * 3.
        plt.subplot(n_rows, 3, i * 3 + 1)
        plt.imshow(image[i].permute(1, 2, 0).cpu().numpy())
        plt.subplot(n_rows, 3, i * 3 + 2)
        plt.imshow(mask[i].cpu().numpy())
        plt.subplot(n_rows, 3, i * 3 + 3)
        plt.imshow(torch.argmax(pred_mask[i], dim=0).cpu().numpy())
    return image, mask, pred_mask


# Predictions on the test set.
image, mask, pred_mask = _show_predictions(my_model, test_dl, num)
# Predictions on the training set.
image, mask, pred_mask = _show_predictions(my_model, train_dl, num)
网友评论