Learning PyTorch: Data Loading and Processing

Author: 张王李刘赵孙杨 | Published: 2018-11-16 23:44

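    Introduction

    This article draws on the source code of the official tutorials and on a few blog posts.

    Loading and processing data in PyTorch is fairly straightforward. Datasets usually come in one of two layouts:

    1. The whole dataset sits in a single folder, together with a separate label file that describes each sample. This layout tends to suit non-classification problems.
    2. The data is split into one folder per class, which suits classification problems. The layout looks like this: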

    root/ants/xxx.png
    root/ants/xxy.jpeg
    root/ants/xxz.png
    .
    .
    .
    root/bees/123.jpg
    root/bees/nsdf3.png
    root/bees/asd932_.png

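    This article first walks through the first method, following the official tutorials, to show how data loading works, and then briefly presents the second method as code. Both methods rest on the same principle; the difference is that the second uses ImageFolder, a ready-made tool from torchvision, which makes it simpler to implement.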

    Method 1

    Dataset class

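    torch.utils.data.Dataset is an abstract class. To load custom data, subclass it and override two methods:

    1. __len__: so that len(dataset) returns the size of the dataset
    2. __getitem__: so that dataset[i] returns the i-th sample

    In the PyTorch version current when this was written, calling either method without overriding it raises NotImplementedError, as the base-class source shows: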
        def __getitem__(self, index):
            raise NotImplementedError
    
        def __len__(self):
            raise NotImplementedError
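To make the two overrides concrete, here is a minimal sketch of a Dataset subclass. The class name SquaresDataset and its contents are hypothetical and chosen only for illustration; they do not come from the tutorial.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i squared)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Makes len(dataset) work
        return self.n

    def __getitem__(self, idx):
        # Makes dataset[idx] work
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        x = torch.tensor(float(idx))
        return x, x ** 2


dataset = SquaresDataset(5)
print(len(dataset))        # 5
x, y = dataset[3]
print(x.item(), y.item())  # 3.0 9.0
```

Raising IndexError for out-of-range indices lets a plain for loop over the dataset terminate normally.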
    
    

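    For this experiment I downloaded 20 images from the web: 10 cats and 10 dogs. To keep things simple (I only wanted to check that subclassing Dataset works), I did not add a label file. Instead, files 1-10 are cats and files 11-20 are dogs, which keeps __len__ and __getitem__ simple. The directory structure is:

    [Figure: directory structure]

    The custom class is as follows: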

    from torch.utils.data import DataLoader, Dataset
    from skimage import io, transform
    import matplotlib.pyplot as plt
    import os
    import torch
    from torchvision import transforms
    import numpy as np

    class AnimalData(Dataset):
        def __init__(self, root_dir, transform=None):
            self.root_dir = root_dir
            self.transform = transform
            # List the files once, in a fixed order, instead of on every access
            self.filenames = sorted(os.listdir(root_dir))

        def __len__(self):
            return len(self.filenames)

        def __getitem__(self, idx):
            filename = self.filenames[idx]
            img = io.imread(os.path.join(self.root_dir, filename))
            # Files are named 1..20: 1-10 are cats (label 1), 11-20 are dogs (label 0).
            # splitext strips the extension whatever its length (.jpg, .jpeg, .png)
            number = int(os.path.splitext(filename)[0])
            if number > 10:
                lable = np.array([0])
            else:
                lable = np.array([1])
            # The key is spelled 'lable' consistently throughout this article's code
            sample = {'image': img, 'lable': lable}

            if self.transform:
                sample = self.transform(sample)
            return sample
    

    Transforms & Compose transforms

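    The __init__ method of the AnimalData class above takes a transform parameter; this section explains what it is for.
    Images downloaded from the web vary in size, whereas a CNN typically requires inputs of a fixed size. NumPy images are laid out as H, W, C, while PyTorch expects C, H, W. PyTorch inputs must be converted from NumPy arrays to tensors, inputs usually need normalizing, and so on.
    To handle this, we can define callable classes and pass them to the dataset class above through its transform parameter. For convenience, torchvision.transforms provides the Compose class, which chains our callable classes so that the dataset returns fully transformed samples.
    Here I copy three classes from the tutorial, Rescale, RandomCrop, and ToTensor, with small changes to fit my dataset.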

    class Rescale(object):
        """Rescale the image in a sample to a given size.
    
        Args:
            output_size (tuple or int): Desired output size. If tuple, output is
                matched to output_size. If int, smaller of image edges is matched
                to output_size keeping aspect ratio the same.
        """
    
        def __init__(self, output_size):
            assert isinstance(output_size, (int, tuple))
            self.output_size = output_size
    
        def __call__(self, sample):
            image, lable = sample['image'], sample['lable']
    
            h, w = image.shape[:2]
            if isinstance(self.output_size, int):
                if h > w:
                    new_h, new_w = self.output_size * h / w, self.output_size
                else:
                    new_h, new_w = self.output_size, self.output_size * w / h
            else:
                new_h, new_w = self.output_size
    
            new_h, new_w = int(new_h), int(new_w)
    
            img = transform.resize(image, (new_h, new_w))
    
            # h and w are swapped for lable because for images,
            # x and y axes are axis 1 and 0 respectively
            # lable = lable * [new_w / w, new_h / h]
    
            return {'image': img, 'lable': lable}
    
    class RandomCrop(object):
        """Crop randomly the image in a sample.
    
        Args:
            output_size (tuple or int): Desired output size. If int, square crop
                is made.
        """
    
        def __init__(self, output_size):
            assert isinstance(output_size, (int, tuple))
            if isinstance(output_size, int):
                self.output_size = (output_size, output_size)
            else:
                assert len(output_size) == 2
                self.output_size = output_size
    
        def __call__(self, sample):
            image, lable = sample['image'], sample['lable']
    
            h, w = image.shape[:2]
            new_h, new_w = self.output_size
    
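            # +1 because randint's upper bound is exclusive; this also avoids
            # a ValueError when the image is exactly the crop size
            top = np.random.randint(0, h - new_h + 1)
            left = np.random.randint(0, w - new_w + 1)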
    
            image = image[top: top + new_h,
                          left: left + new_w]
    
            # lable = lable - [left, top]
    
            return {'image': image, 'lable': lable}
    
    class ToTensor(object):
        """Convert ndarrays in sample to Tensors."""
    
        def __call__(self, sample):
            image, lable = sample['image'], sample['lable']
            # print lable
     
            # swap color axis because
            # numpy image: H x W x C
            # torch image: C X H X W
            image = image.transpose((2, 0, 1))
            return {'image': torch.from_numpy(image),
                    'lable': torch.from_numpy(lable)}
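The axis swap performed in ToTensor can be checked on a small dummy array. This is only an illustrative sketch using NumPy; the array sizes are arbitrary assumptions.

```python
import numpy as np

# Dummy "image": height 4, width 6, 3 color channels (H x W x C)
image_hwc = np.zeros((4, 6, 3), dtype=np.float32)
image_hwc[1, 2, 0] = 1.0  # mark the pixel at row 1, column 2 in channel 0

# Move the channel axis to the front: (H, W, C) -> (C, H, W)
image_chw = image_hwc.transpose((2, 0, 1))

print(image_hwc.shape)     # (4, 6, 3)
print(image_chw.shape)     # (3, 4, 6)
print(image_chw[0, 1, 2])  # 1.0
```

The marked pixel is still found at row 1, column 2 of channel 0; only the order of the axes has changed.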
    

    With the callable classes defined, combine the three with torchvision.transforms.Compose and pass the result as the transform argument of AnimalData.

    trsm = transforms.Compose([Rescale(256),
                                RandomCrop(224),
                                ToTensor()])
    data = AnimalData('./all', transform=trsm)
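Conceptually, Compose calls each transform in order and feeds the output of one into the next. The sketch below illustrates that idea in plain Python. SimpleCompose, AddOne, and Double are hypothetical names made up for this example; this is not torchvision's actual implementation.

```python
class SimpleCompose:
    """Hypothetical stand-in for transforms.Compose: apply callables in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        for t in self.transforms:
            sample = t(sample)
        return sample


class AddOne:
    """Toy transform: add 1 to the sample's value."""

    def __call__(self, sample):
        return {'value': sample['value'] + 1}


class Double:
    """Toy transform: multiply the sample's value by 2."""

    def __call__(self, sample):
        return {'value': sample['value'] * 2}


pipeline = SimpleCompose([AddOne(), Double()])
result = pipeline({'value': 3})
print(result)  # {'value': 8}
```

Order matters: applying Double before AddOne would give 7 instead of 8, just as cropping before rescaling would give a different image than rescaling before cropping.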
    

    Iterating through the dataset

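    With the data instance from the previous section, samples can be read one at a time in a for loop, but this is inefficient. The torch.utils.data.DataLoader class addresses this. Its main features are:

    • Batching the data
    • Shuffling the data
    • Loading the data in parallel using multiprocessing workers

    It is also simple to use: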

    dataloader = DataLoader(data, batch_size=4, shuffle=True, num_workers=4)
    for i_batch, batch_data in enumerate(dataloader):
        print(i_batch)
        print(batch_data['image'].size())
        print(batch_data['lable'])
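To see batching without needing any image files, here is a small sketch that uses TensorDataset with made-up tensors. The data is an assumption for illustration only; shuffle is turned off and num_workers is 0 so the output is deterministic and loading happens in the main process.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten dummy samples with 3 features each, plus integer labels 0..9
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# batch_size=4 over 10 samples gives batches of 4, 4, and 2
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batch_sizes = []
for batch_features, batch_labels in loader:
    batch_sizes.append(batch_features.shape[0])
    print(tuple(batch_features.shape), batch_labels.tolist())

print(len(loader))  # 3
```

The last batch is smaller because 10 is not a multiple of 4; passing drop_last=True to DataLoader would discard it instead.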
    

    Method 2

    torchvision

    PyTorch wraps almost all of the work above for us. One such tool is torchvision.datasets.ImageFolder, which loads custom data, provided the data is organized as follows:

    root/ants/xxx.png
    root/ants/xxy.jpeg
    root/ants/xxz.png
    .
    .
    .
    root/bees/123.jpg
    root/bees/nsdf3.png
    root/bees/asd932_.png

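    torchvision.transforms also provides a variety of data-processing tools, such as Resize and ToTensor.
    I reorganized my downloaded dataset as follows:

    [Figure: directory structure for method 2]

    The loading code is as follows: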

    from torchvision import transforms, utils
    from torchvision import datasets
    import torch
    import matplotlib.pyplot as plt 
    
    train_data = datasets.ImageFolder('./data1', transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ]))
    
    train_loader = torch.utils.data.DataLoader(train_data,
                                                batch_size=4,
                                                shuffle=True,
                                                )
                                                
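    print(len(train_loader))  # number of batches
    for i_batch, img in enumerate(train_loader):
        if i_batch == 0:
            print(img[1])  # img is an (images, labels) pair; these are the labels
            fig = plt.figure()
            grid = utils.make_grid(img[0])  # tile the batch into a single image
            plt.imshow(grid.numpy().transpose((1, 2, 0)))  # C, H, W -> H, W, C
            plt.show()
        break  # only the first batch is shown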
    

    Result:

    [Figure: image grid produced by make_grid]

    Appendix

    Finally, a look at a piece of the torchvision source code:

    # vision/torchvision/datasets/folder.py
    
    import torch.utils.data as data
    
    from PIL import Image
    import os
    import os.path
    
    IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
    
    
    def is_image_file(filename):
        """Checks if a file is an image.
        Args:
            filename (string): path to a file
        Returns:
            bool: True if the filename ends with a known image extension
        """
        filename_lower = filename.lower()
        return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
    
    
    def find_classes(dir):
        classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx
    
    
    def make_dataset(dir, class_to_idx):
        images = []
        dir = os.path.expanduser(dir)
        for target in sorted(os.listdir(dir)):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
    
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    if is_image_file(fname):
                        path = os.path.join(root, fname)
                        item = (path, class_to_idx[target])
                        images.append(item)
    
        return images
    
    
    def pil_loader(path):
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')
    
    
    def accimage_loader(path):
        import accimage
        try:
            return accimage.Image(path)
        except IOError:
            # Potentially a decoding problem, fall back to PIL.Image
            return pil_loader(path)
    
    
    def default_loader(path):
        from torchvision import get_image_backend
        if get_image_backend() == 'accimage':
            return accimage_loader(path)
        else:
            return pil_loader(path)
    
    
    class ImageFolder(data.Dataset):
        """A generic data loader where the images are arranged in this way: ::
            root/dog/xxx.png
            root/dog/xxy.png
            root/dog/xxz.png
            root/cat/123.png
            root/cat/nsdf3.png
            root/cat/asd932_.png
        Args:
            root (string): Root directory path.
            transform (callable, optional): A function/transform that  takes in an PIL image
                and returns a transformed version. E.g, ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
                target and transforms it.
            loader (callable, optional): A function to load an image given its path.
         Attributes:
            classes (list): List of the class names.
            class_to_idx (dict): Dict with items (class_name, class_index).
            imgs (list): List of (image path, class_index) tuples
        """
    
        def __init__(self, root, transform=None, target_transform=None,
                     loader=default_loader):
            classes, class_to_idx = find_classes(root)
            imgs = make_dataset(root, class_to_idx)
            if len(imgs) == 0:
                raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                                   "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
    
            self.root = root
            self.imgs = imgs
            self.classes = classes
            self.class_to_idx = class_to_idx
            self.transform = transform
            self.target_transform = target_transform
            self.loader = loader
    
        def __getitem__(self, index):
            """
            Args:
                index (int): Index
            Returns:
                tuple: (image, target) where target is class_index of the target class.
            """
            path, target = self.imgs[index]
            img = self.loader(path)
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
    
            return img, target
    
        def __len__(self):
            return len(self.imgs)
    
        def __repr__(self):
            fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
            fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
            fmt_str += '    Root Location: {}\n'.format(self.root)
            tmp = '    Transforms (if any): '
            fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            tmp = '    Target Transforms (if any): '
            fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            return fmt_str
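The way find_classes in the listing above assigns label indices can be tried on a temporary directory using only the standard library. The function below repeats the same logic in slightly more compact form; the folder names dog and cat and the stray readme.txt file are made up for this sketch.

```python
import os
import tempfile


def find_classes(root):
    # Same logic as find_classes in the listing above:
    # subdirectories only, sorted by name, indexed from 0
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as root:
    for name in ('dog', 'cat'):
        os.makedirs(os.path.join(root, name))
    # A stray file at the top level is not a directory, so it is ignored
    open(os.path.join(root, 'readme.txt'), 'w').close()
    classes, class_to_idx = find_classes(root)

print(classes)       # ['cat', 'dog']
print(class_to_idx)  # {'cat': 0, 'dog': 1}
```

Because the names are sorted, the label of a class depends on its folder name, not on the order in which the folders were created.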
    

    References

    [1]. Data Loading and Processing Tutorial
    [2]. github: pytorch/torch/utils/data/dataset.py
    [3]. github: vision/torchvision/datasets/folder.py
    [4]. csdn
