美文网首页
pytorch基础四(DataLoader)

pytorch基础四(DataLoader)

作者: 永远学习中 | 来源:发表于2018-12-02 11:52 被阅读0次

本人学习pytorch主要参考官方文档和莫烦Python中的pytorch视频教程。
后文主要是对pytorch官网的文档的总结。

加载csv文件

  • scikit-image用于加载图片并进行转化
  • pandas简单的解析csv格式的文件
    下载faces数据集,解压后放在‘faces/’目录下。
import pandas as pd

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
# Image file name: the row at position n, column 0.
img_name = landmarks_frame.iloc[n, 0]
# All columns after the first hold the landmark coordinates. Series.as_matrix()
# was removed in pandas 1.0, so convert the row with to_numpy() instead.
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
# Reshape the flat row into one (x, y) pair per line.
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

pytorch数据集类

torch.utils.data.Dataset是表示数据集的抽象类,自定义的数据集类应继承Dataset类,并重写下面的方法。

  • __len__,该方法返回数据集的大小
  • __getitem__,该方法为了支持通过dataset[i]获取第i个样本
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from skimage import io
from skimage import transform
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

def show_landmarks(image, landmarks):
    """Draw ``image`` and overlay its landmark points as small red dots.

    Column 0 of ``landmarks`` is used as the x coordinate and column 1 as y.
    """
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.imshow(image)
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    # Brief pause so an interactive figure gets redrawn before continuing.
    plt.pause(0.001)


class FaceLandmarksDataset(Dataset):
    """Face-landmarks dataset indexed by a CSV file.

    Each CSV row holds an image file name in column 0 followed by the
    landmark coordinates, which are returned reshaped into (x, y) pairs.

    Args:
        csv_file (str): path to the CSV file with the annotations.
        root_dir (str): directory the image file names are relative to.
        transform (callable, optional): applied to every sample dict when set.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # One sample per CSV row.
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'landmarks': float array of shape (-1, 2)}`` for row ``idx``."""
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Series.as_matrix() was removed in pandas 1.0; to_numpy() is its replacement.
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}
        if self.transform:
            sample = self.transform(sample)
        return sample

face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                    root_dir='faces/')

fig = plt.figure()

# Preview at most the first four samples side by side in a 1x4 grid.
num_preview = min(4, len(face_dataset))
for i in range(num_preview):
    sample = face_dataset[i]
    print(i, sample['image'].shape, sample['landmarks'].shape)
    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)
# Show once after the loop so the figure appears even when the dataset has
# fewer than four samples (previously show() only ran when i reached 3).
plt.show()

格式转化

这里实现三种转换:缩放图像尺寸(Rescale)、随机裁剪(RandomCrop)、将numpy图像转换为torch格式的张量(ToTensor)。

class Rescale(object):
    """Resize the image of a sample and scale its landmarks to match.

    Args:
        output_size (int or tuple): a tuple is used directly as
            ``(new_h, new_w)``; an int becomes the length of the shorter image
            edge, and the other edge is scaled to keep the aspect ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image = sample['image']
        landmarks = sample['landmarks']
        h, w = image.shape[:2]
        size = self.output_size

        if not isinstance(size, int):
            new_h, new_w = size
        elif h > w:
            new_h, new_w = size * h / w, size
        else:
            new_h, new_w = size, size * w / h
        new_h, new_w = int(new_h), int(new_w)

        resized = transform.resize(image, (new_h, new_w))
        # Column 0 scales with the width ratio, column 1 with the height ratio.
        scaled = landmarks * [new_w / w, new_h / h]
        return {'image': resized, 'landmarks': scaled}

class RandomCrop(object):
    """Randomly crop the image of a sample and shift its landmarks into the crop.

    Args:
        output_size (int or tuple): crop size; an int gives a square crop,
            a 2-tuple is used as ``(new_h, new_w)``.

    Raises:
        ValueError: (on call) if the crop is larger than the image.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError('crop size {} is larger than image size {}'.format(
                (new_h, new_w), (h, w)))
        # randint's upper bound is exclusive: +1 makes the last valid offset
        # (h - new_h / w - new_w) reachable and lets a full-size crop succeed.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        image = image[top: top + new_h,
                      left: left + new_w]
        # Column 0 shifts by the left offset, column 1 by the top offset.
        landmarks = landmarks - [left, top]
        return {'image': image, 'landmarks': landmarks}

class ToTensor(object):
    """Convert the ndarrays of a sample to torch tensors.

    The image axes are reordered from (H, W, C) to (C, H, W). A 2-D (H, W)
    image is treated as single-channel and becomes (1, H, W).
    """

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        if image.ndim == 2:
            # Grayscale images have no channel axis; add one so the transpose works.
            image = image[:, :, None]
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}
# Usage 1: call the transforms directly on a sample obtained from the dataset
# class (`sample` is the last sample fetched by the preview loop).
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each transform to the same original sample and keep every result,
# so no result is overwritten by the next call.
transformed_samples = {}
for name, tsfrm in (('scale', scale), ('crop', crop), ('composed', composed)):
    transformed_samples[name] = tsfrm(sample)
    print(name, transformed_samples[name]['image'].shape,
          transformed_samples[name]['landmarks'].shape)
# Keep `transformed_sample` bound to the composed result, as before.
transformed_sample = transformed_samples['composed']
# Usage 2: hand a transforms.Compose pipeline to the dataset so every sample
# is transformed when it is fetched.
transformed_dataset = FaceLandmarksDataset(
    csv_file='faces/face_landmarks.csv',
    root_dir='faces/',
    transform=transforms.Compose([
        Rescale(256),
        RandomCrop(224),
        ToTensor(),
    ]),
)

# Print the tensor sizes of (at most) the first four transformed samples.
for i in range(min(4, len(transformed_dataset))):
    sample = transformed_dataset[i]
    print(i, sample['image'].size(), sample['landmarks'].size())

批次加载图片、打乱顺序、并行加载

  • torch.utils.data.DataLoader是一个提供上述功能(批次加载、打乱顺序、多进程并行加载)的迭代器。
# DataLoader arguments: the dataset, batch size, whether to shuffle, and the
# number of worker processes used for loading.
dataloader = DataLoader(
    dataset=transformed_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

torchvision包

该包提供了一些常见的数据集以及转换函数,具体用法见迁移学习(Transfer Learning)教程。

import torch
from torchvision import transforms, datasets
# Compose the per-image transform pipeline.
# transforms.RandomSizedCrop is deprecated in favour of RandomResizedCrop
# (and no longer available in recent torchvision); same positional size argument.
data_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
# ImageFolder expects one subdirectory per class under `root`.
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

相关文章

网友评论

      本文标题:pytorch基础四(DataLoader)

      本文链接:https://www.haomeiwen.com/subject/tpzdcqtx.html