美文网首页图像分割
图像语义分割实践(二)数据增强与读取

图像语义分割实践(二)数据增强与读取

作者: 智能之心 | 来源:发表于2022-04-23 17:53 被阅读0次

    图像语义分割实践(二)数据增强与读取

    Pytorch数据加载顺序

    神经网络模型训练过程需要进行梯度更新,梯度更新可分三种方式。1.批梯度下降(batch gradient descent):一次所有数据批计算,过于复杂,计算缓慢;2.随机梯度下降(stochastic gradient descent):每次读一个数据,数据差异大,导致训练波动太大,收敛性不好;3.最小批量梯度下降(mini-batch gradient descent / SGD gradient descent):随机取一定量数据进行训练,既降低计算量,又能提高训练速度。

    使用pytorch对数据进行批次量读取构建,首先了解其加载数据顺序分为以下三个点。

    pytorch中加载数据的顺序分为以下三个点:
    1."创建一个 dataset 对象"; 并加入 transforms 数据增强方案;
    2."创建一个 dataloader 对象";
    3."获取数据集的 mini_batch"; 循环 dataloader 对象, 获取训练样本送入模型进行训练;
    
    其中, 
    "1.创建一个 dataset 对象", 继承 pytorch 的 torch.utils.data.Dataset; 一般需要含3个主要函数:
        1.__init__:    传入数据, 或者直接加载固化的数据包;
        2.__len__:     返回这个数据集一共有多少个item;
        3.__getitem__:  返回一条训练数据, 并将其转换成tensor;
    
    "2.创建一个 dataloader 对象", 采用 pytorch 的 torch.utils.data.DataLoader 整合成 mini_batch;
    
    "3.获取数据集的 mini_batch"
    

    Pytorch官方示例与实践改造

    Pytorch官方示例与实践改造

    1.构建dataset对象.png 2.构建dataloader对象.png 3.索引minibatch数据.png

    数据加载万能模板

    针对自己数据集进行分装,数据列表单元+数据增强单元是我们需要关注的点,所以只要在这两个函数进行改造,其他部分和官方的1.dataset对象,2.dataloader对象,3.mini_batch获取一致。


    4.minibatch可视化.png

    模板代码示例

    ######## py内置函数:help-文件架构, dir-代码架构 ########
    import torch # 包含基本,加减乘除,张量操作,优化器'torch.optim', 数据索引 'torch.utils.data.DataLoader'
    import torch.nn as nn # "类":   包含卷积,池化,激活,损失等 "nn.CrossEntropyLoss()"
    import torch.nn.functional as F  # "函数": 包含卷积,池化,激活,损失等 "F.cross_entropy()"
    import torchvision # 包含图像算法的基本操作等 torchvision.models; torchvision.datasets;
    import torchvision.transforms as T # "类":   包含图像增强方向等 "T.RandomCrop()"
    import torchvision.transforms.functional as TF # "函数": 包含图像增强方向等 "TF.center_crop()"
    import os
    import glob
    import math
    import numpy as np
    import random
    from PIL import Image
    import PIL
    import matplotlib.pyplot as plt
    
    
    #################### 构建 lines 可略 ####################
    class MyLinesGetter(object):
        def __init__(self, FilePath, dtype="seg"):
            self.FilePath = FilePath
            self.dtype = dtype # None="cls", "seg"
        def getter(self):
            self.datalines = []
            with open(self.FilePath, "r") as f:
                lines = f.read().splitlines()
                if self.dtype is 'seg':
                    for line in lines:
                        img_dir, seg_dir = line.split(" ")[:2]
                        img_dir = os.path.join("data_flowers", "JPEGImages", img_dir)
                        seg_dir = os.path.join("data_flowers", "SegmentationClassRAW", seg_dir)
                        self.datalines.append([img_dir, seg_dir])
                else:
                    raise "wrong dtype! check dtype on ['seg']!"
            return self.datalines
    
    #################### 创建 dataset class ####################
    class SegmentDataset(torch.utils.data.Dataset): # 继承
        def __init__(self, dataset, transforms=None):
            self.dataset = dataset
            self.transforms = transforms
        def __len__(self):
            return len(self.dataset)
        def __getitem__(self, idx):
            img_dir, seg_dir = self.dataset[idx]
            img = Image.open(img_dir)
            seg = Image.open(seg_dir)
            if self.transforms is not None:
                data_dict = self.transforms(img, seg)
                img = data_dict['image']
                seg = data_dict["mask"]
            else:
                img = TF.to_tensor(img)
                seg = torch.as_tensor(np.array(seg), dtype=torch.int64)
            return img, seg
        pass
    
    #################### 创建 transforms+Compose 增强方案 ####################
    class Resize(object):
        def __init__(self, size):
            self.size = size
        
        def __call__(self, image, target=None, label=None):
            image = TF.resize(image, self.size)
            if target is not None:
                target = TF.resize(target, self.size, interpolation=PIL.Image.BILINEAR) # PIL.Image.BILINEAR
            if label is not None:
                label = label
            return image, target, label
        pass
    
    class ToTensor(object):
        def __call__(self, image, target=None, label=None):
            image = TF.to_tensor(image)
            if target is not None:
                target = torch.as_tensor(np.array(target), dtype=torch.int64)
            return image, target, label
        pass
    
    # 可用 torchvision 里面的 compose, 为方便看过程,因此自己实现一遍
    class Compose(object):
        def __init__(self, transforms):
            self.transforms = transforms
        def __call__(self, image, mask=None, label=None):
            for t in self.transforms:
                image, mask, label = t(image, mask, label)
            return {'image':image, 'mask':mask, 'label':label}
        pass
    
    if __name__=="__main__":
        # "1.创建一个 dataset 对象"
        train_dataset = SegmentDataset(MyLinesGetter(FilePath="data_flowers/train.txt", dtype="seg").getter(), 
                                       transforms=Compose([Resize((256,256)), ToTensor(),]))
    
        # "2.创建一个 dataloader 对象"
        train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
    
        # "3.获取数据集的 mini_batch"
        for (images, masks) in train_data_loader:
            plt.figure(figsize=(20,20))
            plt.imshow(np.hstack(images.permute(0,2,3,1)))
            plt.show()
            plt.figure(figsize=(20,20))
            plt.imshow(np.hstack(masks))
            plt.show()
            break
    

    参考链接

    植物素材库
    代码高亮
    Pytorch dataset&dataloader
    图像语义分割实践

    相关文章

      网友评论

        本文标题:图像语义分割实践(二)数据增强与读取

        本文链接:https://www.haomeiwen.com/subject/nmyaertx.html