准备工作
本篇文章参考自pytorch官方教程,即末尾参考的第一篇,略去了不必要的乱七八糟的matlab显示功能,保留最实用的数据加载功能。
先从这里下载并解压示例数据集。这里介绍如何创建一个dataloader去加载该文件夹内的数据集。
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
Dataset Class
torch.utils.data.Dataset
是一个抽象类,代表了一个数据集。自定义数据集的时候需要重写两个方法。
__len__
使得len(dataset)
可以返回dataset的大小
__getitem__
支持dataset[i]
可以取出第i个数据。
下面为我们的数据集创建一个dataset类,首先会在__init__
方法中读取csv文件,在__getitem__
方法中读取图片,这样可以节约内存,根据需要读取图片,而不是一次性加载图片到内存中。
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:]
landmarks = np.array([landmarks])
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/')
至此数据就可以从face_dataset中读取了。
变换
可以看到文件夹内的图片大小都不一致,大多数的网络都需要接受统一大小,所以需要对数据进行一些预处理,例如缩放,随机裁剪,转化成张量。
我们会将这些方法写道一个可调用的类中,而不是简单的函数中,如此一来变换的参数就不用每次调用都传递一次。所以我们需要在类中实现__call__
方法,有必要的话还要实现__init__
方法。
我们可以像下面这样调用。
tsfm = Transform(params)
transformed_sample = tsfm(sample)
像下面这样定义
class Rescale(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size.
If tuple, output is matched to output_size.
If int, smaller of image edges is matched to output_size keeping aspect ratio the same.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (new_h, new_w))
# h and w are swapped for landmarks because for images,
# x and y axes are axis 1 and 0 respectively
landmarks = landmarks * [new_w / w, new_h / h]
return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
"""Crop randomly the image in a sample.
Args:
output_size (tuple or int): Desired output size. If int, square crop is made.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[top: top + new_h, left: left + new_w]
landmarks = landmarks - [left, top]
return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
image, landmarks = sample['image'], sample['landmarks']
# swap color axis because
# numpy image: H x W x C
# torch image: C X H X W
image = image.transpose((2, 0, 1))
return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}
组合变换
如果我们需要做很最多变换,就需要把这些类组合到一起。像下面这样
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
RandomCrop(224)])
# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample)
ax = plt.subplot(1, 3, i + 1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
plt.show()
迭代器
下面把这些变换都结合到一起创建一个dataset。所有图片都从文件名中,变换在读取图片是生效,每一个变换都是随机的。
transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/',
transform=transforms.Compose([
Rescale(256),
RandomCrop(224),
ToTensor()
]))
然而我们丢失了一些特征,比如数据的批大小,数据随机,多gpu并行处理。可以用dataloader来玩。
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)
然后可以遍历dataloader,读取里面的数据。
还有一点没看完
## Afterword: torchvision
[Afterword: torchvision](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#afterword-torchvision)
In this tutorial, we have seen how to write and use datasets, transforms and dataloader. `torchvision` package provides some common datasets and transforms. You might not even have to write custom classes. One of the more generic datasets available in torchvision is `ImageFolder`. It assumes that images are organized in the following way:
<pre style="box-sizing: border-box; font-family: IBMPlexMono, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; font-size: 14px; margin-top: 0px; margin-bottom: 2.5rem; overflow: auto; display: block; color: rgb(33, 37, 41); padding: 1.375rem; background-color: rgb(243, 244, 247); white-space: pre-wrap; overflow-wrap: break-word;">root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
</pre>
where ‘ants’, ‘bees’ etc. are class labels. Similarly generic transforms which operate on `PIL.Image` like `RandomHorizontalFlip`, `Scale`, are also available. You can use these to write a dataloader like this:
<pre style="box-sizing: border-box; font-family: IBMPlexMono, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; font-size: 14px; margin-top: 0px; margin-bottom: 2.5rem; overflow: auto; display: block; color: rgb(33, 37, 41); padding: 1.375rem; background-color: rgb(243, 244, 247); white-space: pre-wrap; overflow-wrap: break-word;">import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)</pre>
网友评论