Data loading and preprocessing in PyTorch

Author: callme周小伦 | Published 2019-05-14 15:03

Before starting a PyTorch project we first need to preprocess the data. The data is usually organized in one of two ways:

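  1. All images are in a single folder. In this case there is usually a separate text file holding the labels.
  2. Images of different classes are placed in different folders (no separate label file is needed).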

Two ways to load and preprocess the data

1. We start with the first case, which is covered in detail in the DATA LOADING AND PROCESSING TUTORIAL chapter of the PyTorch tutorials.

  • In this case we usually use torch.utils.data.Dataset: subclass Dataset and override its three methods __init__, __len__ and __getitem__. The example below follows the one in the PyTorch tutorial:
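import os

import pandas as pd
from skimage import io
from torch.utils.data import Dataset


class FaceDataset(Dataset):
    def __init__(self, root_dir, file_dir, transform=None):
        # root_dir: folder that contains the images
        # file_dir: path to the CSV file; column 0 is the image file name,
        #           the remaining columns are the landmark coordinates
        self.root_dir = root_dir
        self.transform = transform
        self.landmarks_frame = pd.read_csv(file_dir)

    def __len__(self):
        # one sample per CSV row
        return len(self.landmarks_frame)

    def __getitem__(self, item):
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[item, 0])
        image = io.imread(img_name)

        # to_numpy() replaces the as_matrix() method removed in pandas 1.0
        landmarks = self.landmarks_frame.iloc[item, 1:].to_numpy()
        landmarks = landmarks.astype(float).reshape(-1, 2)  # one (x, y) pair per row

        sample = {'image': image, 'landmarks': landmarks}
        if self.transform:
            sample = self.transform(sample)
        return sample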

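  • The transform can be built from the operations in torchvision.transforms or from custom operations. A custom transform is a class that implements __init__ and __call__. Because each sample here is a dict holding both the image and its landmarks, the transforms below take and return that dict, so the landmarks are updated together with the image: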
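import numpy as np
import torch
from skimage import transform


class Rescale(object):
    """Resize the image in a sample to outsize.

    An int outsize is matched to the shorter side (aspect ratio kept);
    a tuple (h, w) gives the exact output size.
    """

    def __init__(self, outsize):
        assert isinstance(outsize, (int, tuple))
        self.outsize = outsize

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]

        if isinstance(self.outsize, int):
            if h > w:
                new_h, new_w = self.outsize * h / w, self.outsize
            else:
                new_h, new_w = self.outsize, self.outsize * w / h
        else:
            new_h, new_w = self.outsize

        new_h, new_w = int(new_h), int(new_w)

        image = transform.resize(image, (new_h, new_w))
        # landmarks are stored as (x, y), i.e. (column, row):
        # x scales with the width, y with the height
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': image, 'landmarks': landmarks}


class RandomCrop(object):
    """Randomly crop the image in a sample to outsize (an int gives a square crop)."""

    def __init__(self, outsize):
        assert isinstance(outsize, (int, tuple))
        if isinstance(outsize, int):
            self.outsize = (outsize, outsize)
        else:
            assert len(outsize) == 2
            self.outsize = outsize

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.outsize

        # randint's upper bound is exclusive; +1 also allows offset 0 when h == new_h
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        # rows are bounded by new_h, columns by new_w
        image = image[top: top + new_h, left: left + new_w]
        # shift (x, y) landmarks by the crop offset
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert the numpy arrays in a sample to tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        # numpy image: H x W x C  ->  torch image: C x H x W
        image = image.transpose((2, 0, 1))

        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}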
  • The custom FaceDataset takes its transforms in the usual way, chained with transforms.Compose:
from torchvision import transforms

face_dataset = FaceDataset(root_dir='data/faces/', file_dir='data/faces/face_landmarks.csv',
                           transform=transforms.Compose([
                               Rescale(256),
                               RandomCrop(224),
                               ToTensor()
                           ]))
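
Section 1 stops at building face_dataset. As a minimal sketch of what happens when a dict-returning dataset like this is wrapped in a DataLoader, the example below uses a hypothetical ToyFaceDataset that returns random tensors in the same {'image', 'landmarks'} layout instead of reading files; the 3x32x32 image size and 68 landmarks are illustrative assumptions. PyTorch's default collate function batches dict samples key by key, so each batch is again a dict whose values gain a leading batch dimension.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyFaceDataset(Dataset):
    """Hypothetical stand-in for FaceDataset: same dict layout, no file I/O.
    The image is already C x H x W, as it would be after ToTensor()."""

    def __init__(self, n_samples=10):
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, item):
        image = torch.rand(3, 32, 32)    # assumed image size
        landmarks = torch.rand(68, 2)    # assumed 68 (x, y) landmark pairs
        return {'image': image, 'landmarks': landmarks}


loader = DataLoader(ToyFaceDataset(), batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(loader))

# The default collate_fn stacks each dict entry separately
print(batch['image'].shape)      # torch.Size([4, 3, 32, 32])
print(batch['landmarks'].shape)  # torch.Size([4, 68, 2])
```

The same loop works with the real face_dataset, as long as every sample has the same image size (which RandomCrop(224) guarantees); otherwise the default collate cannot stack the images.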

2. The second case is illustrated with the example from the DCGAN tutorial. Here we call torchvision.datasets.ImageFolder():

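from torchvision import datasets, transforms

# args holds the parsed command-line options (data_root, image_size, ...)
dataset = datasets.ImageFolder(root=args.data_root,
                               transform=transforms.Compose([
                                   transforms.Resize(args.image_size),
                                   transforms.CenterCrop(args.image_size),
                                   transforms.ToTensor(),
                                   # map pixel values from [0, 1] to [-1, 1]
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))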

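The root argument is the training-data directory. Here data_root contains a single folder, img_align_celeba, so ImageFolder uses img_align_celeba as the class name and then maps each class name to an integer label. Since there is only one class, every label is 0.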
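
To make this labelling rule concrete, here is a small stdlib-only sketch. find_classes is a hypothetical helper written for illustration: to the best of my knowledge it follows the same rule torchvision uses (each sub-directory is a class, classes are numbered from 0 in sorted-name order), but it is not the library's own code.

```python
import os
import tempfile


def find_classes(root):
    """Sketch of ImageFolder's labelling rule: every sub-directory of root
    is one class, numbered from 0 in sorted-name order."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"no class folders found in {root}")
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


# Recreate the single-folder CelebA layout described above
with tempfile.TemporaryDirectory() as data_root:
    os.makedirs(os.path.join(data_root, "img_align_celeba"))
    classes, class_to_idx = find_classes(data_root)

print(classes)       # ['img_align_celeba']
print(class_to_idx)  # {'img_align_celeba': 0}
```

With several folders (say cats/ and dogs/), the sorted order decides the labels: cats becomes 0 and dogs becomes 1.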

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size,
                        num_workers=args.workers)
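
batch_size also determines how many batches the dataloader yields per epoch. A minimal sketch, assuming a hypothetical stand-in dataset of 10 samples built with TensorDataset instead of ImageFolder: len() of a DataLoader is the number of batches, the last batch is smaller when the dataset size is not a multiple of batch_size, and drop_last=True discards that incomplete batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in dataset: 10 tiny 3x4x4 "images", all labelled 0
toy_dataset = TensorDataset(torch.zeros(10, 3, 4, 4), torch.zeros(10, dtype=torch.long))

loader = DataLoader(toy_dataset, batch_size=4, shuffle=True)
loader_drop = DataLoader(toy_dataset, batch_size=4, shuffle=True, drop_last=True)

# len() counts batches: ceil(10 / 4) = 3, or 10 // 4 = 2 with drop_last=True
print(len(loader), len(loader_drop))  # 3 2

# shuffle changes which samples land in each batch, not the batch sizes
sizes = [images.shape[0] for images, labels in loader]
print(sizes)  # [4, 4, 2]
```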

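Without any extra definitions, each batch produced by the dataloader contains two parts: the training images and the training labels. For example:
real_batch = next(iter(dataloader))
Each batch is a list: real_batch[0] is the tensor of training images and real_batch[1] is the tensor of training labels. ImageFolder treats the images in each folder as one class and returns the labels as integers; since there is only one class here, all labels are 0.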
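
The sketch below makes this batch structure concrete. It replaces ImageFolder with a hypothetical TensorDataset of 8 random 3x64x64 images whose labels are all 0, mimicking the single-class folder above. Like ImageFolder, it returns (image, label) pairs, so each batch comes back as a two-element list that can be unpacked directly.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the single-class ImageFolder dataset
images = torch.rand(8, 3, 64, 64)
labels = torch.zeros(8, dtype=torch.long)
dataloader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=True)

real_batch = next(iter(dataloader))    # [image tensor, label tensor]
batch_images, batch_labels = real_batch

print(type(real_batch).__name__)  # list
print(batch_images.shape)         # torch.Size([4, 3, 64, 64])
print(batch_labels)               # tensor([0, 0, 0, 0])
```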

Original link: https://www.haomeiwen.com/subject/lxmsaqtx.html