美文网首页
pytorch 数据准备

pytorch 数据准备

作者: zestloveheart | 来源:发表于2019-07-18 22:06 被阅读0次

    准备工作

    本篇文章参考自pytorch官方教程,即末尾参考的第一篇,略去了不必要的乱七八糟的matlab显示功能,保留最实用的数据加载功能。
    先从这里下载并解压示例数据集。这里介绍如何创建一个dataloader去加载该文件夹内的数据集。

    import torch
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms, utils
    import os
    import pandas as pd
    from skimage import io, transform
    import numpy as np
    import matplotlib.pyplot as plt
    

    Dataset Class

    torch.utils.data.Dataset是一个抽象类,代表了一个数据集。自定义数据集的时候需要重写两个方法。
    __len__使得len(dataset)可以返回dataset的大小
    __getitem__支持dataset[i]可以取出第i个数据。
    下面为我们的数据集创建一个dataset类,首先会在__init__方法中读取csv文件,在__getitem__方法中读取图片,这样可以节约内存,根据需要读取图片,而不是一次性加载图片到内存中。

    class FaceLandmarksDataset(Dataset):
        """Face Landmarks dataset."""
        def __init__(self, csv_file, root_dir, transform=None):
            """
            Args:
                csv_file (string): Path to the csv file with annotations.
                root_dir (string): Directory with all the images.
                transform (callable, optional): Optional transform to be applied on a sample.
            """
            self.landmarks_frame = pd.read_csv(csv_file)
            self.root_dir = root_dir
            self.transform = transform
    
        def __len__(self):
            return len(self.landmarks_frame)
    
        def __getitem__(self, idx):
            img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])
            image = io.imread(img_name)
            landmarks = self.landmarks_frame.iloc[idx, 1:]
            landmarks = np.array([landmarks])
            landmarks = landmarks.astype('float').reshape(-1, 2)
            sample = {'image': image, 'landmarks': landmarks}
            if self.transform:
                sample = self.transform(sample)
            return sample
    face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/')
    

    至此数据就可以从face_dataset中读取了。

    变换

    可以看到文件夹内的图片大小都不一致,大多数的网络都需要接受统一大小,所以需要对数据进行一些预处理,例如缩放,随机裁剪,转化成张量。
    我们会将这些方法写道一个可调用的类中,而不是简单的函数中,如此一来变换的参数就不用每次调用都传递一次。所以我们需要在类中实现__call__方法,有必要的话还要实现__init__方法。
    我们可以像下面这样调用。

    tsfm = Transform(params)
    transformed_sample = tsfm(sample)
    

    像下面这样定义

    class Rescale(object):
        """Rescale the image in a sample to a given size.
        Args:
            output_size (tuple or int): Desired output size. 
                If tuple, output is matched to output_size. 
                If int, smaller of image edges is matched to output_size keeping aspect ratio the same.
        """
        def __init__(self, output_size):
            assert isinstance(output_size, (int, tuple))
            self.output_size = output_size
    
        def __call__(self, sample):
            image, landmarks = sample['image'], sample['landmarks']
            h, w = image.shape[:2]
            if isinstance(self.output_size, int):
                if h > w:
                    new_h, new_w = self.output_size * h / w, self.output_size
                else:
                    new_h, new_w = self.output_size, self.output_size * w / h
            else:
                new_h, new_w = self.output_size
            new_h, new_w = int(new_h), int(new_w)
            img = transform.resize(image, (new_h, new_w))
            # h and w are swapped for landmarks because for images,
            # x and y axes are axis 1 and 0 respectively
            landmarks = landmarks * [new_w / w, new_h / h]
            return {'image': img, 'landmarks': landmarks}
    
    class RandomCrop(object):
        """Crop randomly the image in a sample.
        Args:
            output_size (tuple or int): Desired output size. If int, square crop is made.
        """
        def __init__(self, output_size):
            assert isinstance(output_size, (int, tuple))
            if isinstance(output_size, int):
                self.output_size = (output_size, output_size)
            else:
                assert len(output_size) == 2
                self.output_size = output_size
    
        def __call__(self, sample):
            image, landmarks = sample['image'], sample['landmarks']
            h, w = image.shape[:2]
            new_h, new_w = self.output_size
            top = np.random.randint(0, h - new_h)
            left = np.random.randint(0, w - new_w)
            image = image[top: top + new_h, left: left + new_w]
            landmarks = landmarks - [left, top]
            return {'image': image, 'landmarks': landmarks}
    
    class ToTensor(object):
        """Convert ndarrays in sample to Tensors."""
        def __call__(self, sample):
            image, landmarks = sample['image'], sample['landmarks']
            # swap color axis because
            # numpy image: H x W x C
            # torch image: C X H X W
            image = image.transpose((2, 0, 1))
            return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}
    

    组合变换

    如果我们需要做很最多变换,就需要把这些类组合到一起。像下面这样

    scale = Rescale(256)
    crop = RandomCrop(128)
    composed = transforms.Compose([Rescale(256),
                                   RandomCrop(224)])
    
    # Apply each of the above transforms on sample.
    fig = plt.figure()
    sample = face_dataset[65]
    for i, tsfrm in enumerate([scale, crop, composed]):
        transformed_sample = tsfrm(sample)
    
        ax = plt.subplot(1, 3, i + 1)
        plt.tight_layout()
        ax.set_title(type(tsfrm).__name__)
    plt.show()
    

    迭代器

    下面把这些变换都结合到一起创建一个dataset。所有图片都从文件名中,变换在读取图片是生效,每一个变换都是随机的。

    transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                               root_dir='data/faces/',
                                               transform=transforms.Compose([
                                                   Rescale(256),
                                                   RandomCrop(224),
                                                   ToTensor()
                                               ]))
    

    然而我们丢失了一些特征,比如数据的批大小,数据随机,多gpu并行处理。可以用dataloader来玩。

    dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)
    

    然后可以遍历dataloader,读取里面的数据。

    还有一点没看完

    ## Afterword: torchvision
    [Afterword: torchvision](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#afterword-torchvision)
    
    In this tutorial, we have seen how to write and use datasets, transforms and dataloader. `torchvision` package provides some common datasets and transforms. You might not even have to write custom classes. One of the more generic datasets available in torchvision is `ImageFolder`. It assumes that images are organized in the following way:
    
    <pre style="box-sizing: border-box; font-family: IBMPlexMono, SFMono-Regular, Menlo, Monaco, Consolas, &quot;Liberation Mono&quot;, &quot;Courier New&quot;, monospace; font-size: 14px; margin-top: 0px; margin-bottom: 2.5rem; overflow: auto; display: block; color: rgb(33, 37, 41); padding: 1.375rem; background-color: rgb(243, 244, 247); white-space: pre-wrap; overflow-wrap: break-word;">root/ants/xxx.png
    root/ants/xxy.jpeg
    root/ants/xxz.png
    .
    .
    .
    root/bees/123.jpg
    root/bees/nsdf3.png
    root/bees/asd932_.png
    </pre>
    
    where ‘ants’, ‘bees’ etc. are class labels. Similarly generic transforms which operate on `PIL.Image` like `RandomHorizontalFlip`, `Scale`, are also available. You can use these to write a dataloader like this:
    
    <pre style="box-sizing: border-box; font-family: IBMPlexMono, SFMono-Regular, Menlo, Monaco, Consolas, &quot;Liberation Mono&quot;, &quot;Courier New&quot;, monospace; font-size: 14px; margin-top: 0px; margin-bottom: 2.5rem; overflow: auto; display: block; color: rgb(33, 37, 41); padding: 1.375rem; background-color: rgb(243, 244, 247); white-space: pre-wrap; overflow-wrap: break-word;">import torch
    from torchvision import transforms, datasets
    
    data_transform = transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                               transform=data_transform)
    dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                                 batch_size=4, shuffle=True,
                                                 num_workers=4)</pre>
    
    

    参考

    Data Loading and Processing

    相关文章

      网友评论

          本文标题:pytorch 数据准备

          本文链接:https://www.haomeiwen.com/subject/pvmrlctx.html