美文网首页
pytorch基础四(DataLoader)

pytorch基础四(DataLoader)

作者: 永远学习中 | 来源:发表于2018-12-02 11:52 被阅读0次

    本人学习pytorch主要参考官方文档和莫烦Python中的pytorch视频教程。
    后文主要是对pytorch官网的文档的总结。

    加载csv文件

    • scikit-image用于加载图片并进行转化
    • pandas简单的解析csv格式的文件
      下载faces数据集,解压后放在'faces/'目录下。
    import pandas as pd

    landmarks_frame = pd.read_csv('faces/face_landmarks.csv')
    
    n = 65
    # Column 0 of row n holds the image file name.
    img_name = landmarks_frame.iloc[n, 0]
    # The remaining columns hold the landmark coordinates; pull them out as a
    # numpy array (.values replaces Series.as_matrix(), removed in pandas 1.0).
    landmarks = landmarks_frame.iloc[n, 1:].values
    # Reshape the flat row into an (N, 2) array: one (x, y) pair per landmark.
    landmarks = landmarks.astype('float').reshape(-1, 2)
    
    print('Image name: {}'.format(img_name))
    print('Landmarks shape: {}'.format(landmarks.shape))
    print('First 4 Landmarks: {}'.format(landmarks[:4]))
    

    pytorch数据集类

    torch.utils.data.Dataset是表示数据集的抽象类,自定义的数据集类应该继承Dataset类,并重写下面的方法。

    • __len__,该方法返回数据集的大小
    • __getitem__,该方法为了支持通过dataset[i]获取第i个样本
    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import torch
    from skimage import io
    from skimage import transform
    from torch.utils.data import DataLoader
    from torch.utils.data import Dataset
    from torchvision import transforms
    
    def show_landmarks(image, landmarks):
        """Draw an image and overlay its landmark points as small red dots.

        Args:
            image: Image accepted by ``plt.imshow``.
            landmarks: 2-D array whose column 0 is plotted on the x axis and
                column 1 on the y axis.
        """
        plt.imshow(image)
        xs = landmarks[:, 0]
        ys = landmarks[:, 1]
        plt.scatter(xs, ys, s=10, marker='.', c='r')
        # Short pause: matplotlib runs its GUI event loop meanwhile, so the
        # figure gets a chance to redraw.
        plt.pause(0.001)
    
    
    class FaceLandmarksDataset(Dataset):
        """Face-landmarks dataset backed by a CSV annotation file.

        Each CSV row holds an image file name in column 0; the remaining
        columns are the landmark coordinates, which are reshaped into
        (x, y) pairs when a sample is loaded.
        """

        def __init__(self, csv_file, root_dir, transform=None):
            """Read the annotations and remember where the images live.

            Args:
                csv_file: CSV file with the annotations (handed to
                    ``pd.read_csv``).
                root_dir: Directory that the image file names are relative to.
                transform: Optional callable applied to every sample dict.
            """
            self.landmarks_frame = pd.read_csv(csv_file)
            self.root_dir = root_dir
            self.transform = transform

        def __len__(self):
            """Return the number of annotated samples (rows of the CSV)."""
            return len(self.landmarks_frame)

        def __getitem__(self, idx):
            """Load sample ``idx``.

            Returns:
                dict with keys ``'image'`` (array read by ``io.imread``) and
                ``'landmarks'`` (float array of shape (N, 2)), passed through
                ``transform`` when one was given.
            """
            img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
            image = io.imread(img_name)
            # .values replaces Series.as_matrix(), which pandas 1.0 removed.
            landmarks = self.landmarks_frame.iloc[idx, 1:].values
            landmarks = landmarks.astype('float').reshape(-1, 2)
            sample = {'image': image, 'landmarks': landmarks}
            if self.transform:
                sample = self.transform(sample)
            return sample
    
    # Build the dataset without a transform, so samples come back as the raw
    # image and landmark arrays.
    face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                        root_dir='faces/')
    
    fig = plt.figure()
    
    # Preview the first four samples side by side in a 1x4 subplot grid.
    for i in range(len(face_dataset)):
        sample = face_dataset[i]
        print(i, sample['image'].shape, sample['landmarks'].shape)
        ax = plt.subplot(1, 4, i + 1)
        plt.tight_layout()
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        # The sample dict keys match show_landmarks' parameter names.
        show_landmarks(**sample)
        # Only four subplot slots exist: show the figure after the fourth
        # sample and stop.
        if i == 3:
            plt.show()
            break
    

    格式转化

    有三种格式转化:缩放图像尺寸、随机裁切、将numpy图像转化为torch格式图像。

    class Rescale(object):
        """Resize the image of a sample and scale its landmarks to match.

        Args:
            output_size (int or tuple): Desired output size. A tuple is used
                as (height, width). With an int, the shorter image edge is
                resized to that value and the aspect ratio is kept.
        """

        def __init__(self, output_size):
            assert isinstance(output_size, (int, tuple))
            if isinstance(output_size, tuple):
                # Reject malformed tuples here instead of failing later with
                # an unpacking error inside __call__.
                assert len(output_size) == 2
            self.output_size = output_size

        def __call__(self, sample):
            """Return a new sample dict with the resized image and landmarks."""
            image, landmarks = sample['image'], sample['landmarks']
            h, w = image.shape[:2]
            if isinstance(self.output_size, int):
                # Shorter edge becomes output_size; the longer edge is scaled
                # by the same factor.
                if h > w:
                    new_h, new_w = self.output_size * h / w, self.output_size
                else:
                    new_h, new_w = self.output_size, self.output_size * w / h
            else:
                new_h, new_w = self.output_size
            new_h, new_w = int(new_h), int(new_w)
            img = transform.resize(image, (new_h, new_w))
            # Landmarks are (x, y) pairs: x scales with the width ratio and
            # y with the height ratio.
            landmarks = landmarks * [new_w / w, new_h / h]
            return {'image': img, 'landmarks': landmarks}
    
    class RandomCrop(object):
        """Crop a random window out of the image of a sample.

        Args:
            output_size (int or tuple): Desired crop size. A tuple is used as
                (height, width); an int produces a square crop.
        """

        def __init__(self, output_size):
            assert isinstance(output_size, (int, tuple))
            if isinstance(output_size, int):
                self.output_size = (output_size, output_size)
            else:
                assert len(output_size) == 2
                self.output_size = output_size

        def __call__(self, sample):
            """Return a new sample dict with the cropped image and shifted landmarks."""
            image, landmarks = sample['image'], sample['landmarks']
            h, w = image.shape[:2]
            new_h, new_w = self.output_size
            # randint's upper bound is exclusive, so "+ 1" is needed to allow
            # the last valid offset and to accept an image that already has
            # exactly the crop size (offset 0).
            top = np.random.randint(0, h - new_h + 1)
            left = np.random.randint(0, w - new_w + 1)
            image = image[top: top + new_h,
                          left: left + new_w]
            # Landmarks are (x, y) pairs, so shift x by left and y by top.
            landmarks = landmarks - [left, top]
            return {'image': image, 'landmarks': landmarks}
    
    class ToTensor(object):
        """Turn the numpy arrays of a sample into torch tensors."""

        def __call__(self, sample):
            """Return a dict with 'image' and 'landmarks' wrapped as tensors."""
            image = sample['image']
            landmarks = sample['landmarks']
            tensors = {}
            # Reorder the image axes to (2, 0, 1): the last axis moves to the
            # front before the array is wrapped.
            tensors['image'] = torch.from_numpy(image.transpose(2, 0, 1))
            tensors['landmarks'] = torch.from_numpy(landmarks)
            return tensors
    # Usage 1: call the transforms directly; `sample` is a sample obtained from the dataset class.
    scale = Rescale(256)
    crop = RandomCrop(128)
    composed = transforms.Compose([Rescale(256),
                                   RandomCrop(224)])
    
    # Every call receives the same original `sample`; transformed_sample is
    # simply overwritten by each result in turn.
    transformed_sample = scale(sample)
    transformed_sample = crop(sample)
    transformed_sample = composed(sample)
    # Usage 2: hand a transforms.Compose pipeline to the dataset, which applies it to every sample.
    transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                               root_dir='faces/',
                                               transform=transforms.Compose([
                                                   Rescale(256),
                                                   RandomCrop(224),
                                                   ToTensor()
                                               ]))
    
    # Print the tensor sizes of the first four transformed samples.
    for i in range(len(transformed_dataset)):
        sample = transformed_dataset[i]
        print(i, sample['image'].size(), sample['landmarks'].size())
        if i == 3:
            break
    

    批次加载图片、打乱顺序、并行加载

    • torch.utils.data.DataLoader是一个提供上述全部功能(批次加载、打乱顺序、并行加载)的迭代器。
    # Arguments: the dataset to load from, the batch size, whether to shuffle, and the number of parallel workers.
    dataloader = DataLoader(transformed_dataset, batch_size=4,
                            shuffle=True, num_workers=4)
    

    torchvision类

    该包提供了一些常见的数据集以及转换函数,具体用法参见官方的Transfer Learning教程。

    import torch
    from torchvision import transforms, datasets
    # Compose the individual transforms into a single pipeline.
    data_transform = transforms.Compose([
            # RandomResizedCrop is the current name of the deprecated
            # RandomSizedCrop, which newer torchvision releases no longer ship.
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    # ImageFolder reads the images under `root` and applies the pipeline to each one.
    hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                               transform=data_transform)
    dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                                 batch_size=4, shuffle=True,
                                                 num_workers=4)
    

    相关文章

      网友评论

          本文标题:pytorch基础四(DataLoader)

          本文链接:https://www.haomeiwen.com/subject/tpzdcqtx.html