图像语义分割实践(二)数据增强与读取
Pytorch数据加载顺序
神经网络模型训练过程需要进行梯度更新,梯度更新可分三种方式。1.批梯度下降(batch gradient descent):一次所有数据批计算,过于复杂,计算缓慢;2.随机梯度下降(stochastic gradient descent):每次读一个数据,数据差异大,导致训练波动太大,收敛性不好;3.最小批量梯度下降(mini-batch gradient descent / SGD gradient descent):随机取一定量数据进行训练,既降低计算量,又能提高训练速度。
使用pytorch对数据进行批次量读取构建,首先了解其加载数据顺序分为以下三个点。
pytorch中加载数据的顺序分为以下三个点:
1."创建一个 dataset 对象"; 并加入 transforms 数据增强方案;
2."创建一个 dataloader 对象";
3."获取数据集的 mini_batch"; 循环 dataloader 对象, 获取训练样本送入模型进行训练;
其中,
"1.创建一个 dataset 对象", 继承 pytorch 的 torch.utils.data.Dataset; 一般需要含3个主要函数:
1.__init__: 传入数据, 或者直接加载固化的数据包;
2.__len__: 返回这个数据集一共有多少个item;
3.__getitem__: 返回一条训练数据, 并将其转换成tensor;
"2.创建一个 dataloader 对象", 采用 pytorch 的 torch.utils.data.DataLoader 整合成 mini_batch;
"3.获取数据集的 mini_batch"
Pytorch官方示例与实践改造
1.构建dataset对象.png 2.构建dataloader对象.png 3.索引minibatch数据.png数据加载万能模板
针对自己数据集进行分装,数据列表单元+数据增强单元是我们需要关注的点,所以只要在这两个函数进行改造,其他部分和官方的1.dataset对象,2.dataloader对象,3.mini_batch获取一致。
4.minibatch可视化.png
模板代码示例
######## py内置函数:help-文件架构, dir-代码架构 ########
import torch # 包含基本,加减乘除,张量操作,优化器'torch.optim', 数据索引 'torch.utils.data.DataLoader'
import torch.nn as nn # "类": 包含卷积,池化,激活,损失等 "nn.CrossEntropyLoss()"
import torch.nn.functional as F # "函数": 包含卷积,池化,激活,损失等 "F.cross_entropy()"
import torchvision # 包含图像算法的基本操作等 torchvision.models; torchvision.datasets;
import torchvision.transforms as T # "类": 包含图像增强方向等 "T.RandomCrop()"
import torchvision.transforms.functional as TF # "函数": 包含图像增强方向等 "TF.center_crop()"
import os
import glob
import math
import numpy as np
import random
from PIL import Image
import PIL
import matplotlib.pyplot as plt
#################### 构建 lines 可略 ####################
class MyLinesGetter(object):
def __init__(self, FilePath, dtype="seg"):
self.FilePath = FilePath
self.dtype = dtype # None="cls", "seg"
def getter(self):
self.datalines = []
with open(self.FilePath, "r") as f:
lines = f.read().splitlines()
if self.dtype is 'seg':
for line in lines:
img_dir, seg_dir = line.split(" ")[:2]
img_dir = os.path.join("data_flowers", "JPEGImages", img_dir)
seg_dir = os.path.join("data_flowers", "SegmentationClassRAW", seg_dir)
self.datalines.append([img_dir, seg_dir])
else:
raise "wrong dtype! check dtype on ['seg']!"
return self.datalines
#################### 创建 dataset class ####################
class SegmentDataset(torch.utils.data.Dataset): # 继承
def __init__(self, dataset, transforms=None):
self.dataset = dataset
self.transforms = transforms
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
img_dir, seg_dir = self.dataset[idx]
img = Image.open(img_dir)
seg = Image.open(seg_dir)
if self.transforms is not None:
data_dict = self.transforms(img, seg)
img = data_dict['image']
seg = data_dict["mask"]
else:
img = TF.to_tensor(img)
seg = torch.as_tensor(np.array(seg), dtype=torch.int64)
return img, seg
pass
#################### 创建 transforms+Compose 增强方案 ####################
class Resize(object):
def __init__(self, size):
self.size = size
def __call__(self, image, target=None, label=None):
image = TF.resize(image, self.size)
if target is not None:
target = TF.resize(target, self.size, interpolation=PIL.Image.BILINEAR) # PIL.Image.BILINEAR
if label is not None:
label = label
return image, target, label
pass
class ToTensor(object):
def __call__(self, image, target=None, label=None):
image = TF.to_tensor(image)
if target is not None:
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target, label
pass
# 可用 torchvision 里面的 compose, 为方便看过程,因此自己实现一遍
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, mask=None, label=None):
for t in self.transforms:
image, mask, label = t(image, mask, label)
return {'image':image, 'mask':mask, 'label':label}
pass
if __name__=="__main__":
# "1.创建一个 dataset 对象"
train_dataset = SegmentDataset(MyLinesGetter(FilePath="data_flowers/train.txt", dtype="seg").getter(),
transforms=Compose([Resize((256,256)), ToTensor(),]))
# "2.创建一个 dataloader 对象"
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
# "3.获取数据集的 mini_batch"
for (images, masks) in train_data_loader:
plt.figure(figsize=(20,20))
plt.imshow(np.hstack(images.permute(0,2,3,1)))
plt.show()
plt.figure(figsize=(20,20))
plt.imshow(np.hstack(masks))
plt.show()
break
网友评论