解决机器学习问题中很多的工作都是在处理数据。Pytorch提供许多工具,是的数据操作更加简便有用,使代码更加的易读。
Dataset class
Dataset位于torch.units.data.Dataset
,是一个抽象类,用于代表数据集。我们可以继承它,然后重写以下两个方法:
-
__len__
用于返回数据集长度len(dataset)
-
__getitem__
用于使用下标索引如dataset[i]
总结来说,我们通过继承Dataset
类,实现其方法后,将数据封装于此类中。如人脸数据集的示例。
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir,
self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
Transforms
addon:torchvision包含三个模块:dataset
包含常见数据集,models
包含常见模型,tranforms
用于对影像进行变换。
Tranforms
位于torchvision.transforms
,即对数据集进行变换,如影像数据集的裁减,缩放等。
-
transforms.Compose
将若干transforms
操作集合起来,如
transforms.Compose([transforms.CenterCrop(10),transforms.ToTensor(),])
将CenterCrop
与ToTensor
结合起来。
小示例如:
import torchvision.transforms as transforms
In [6]: trans1 = transforms.Compose([transforms.ToTensor(),])
In [7]: img = cv2.imread('elephant.png')
In [8]: img1 = trans1(img)
In [9]: type(img1)
Out[9]: torch.Tensor
In [10]: type(img)
Out[10]: numpy.ndarray
In [11]: trans2 = transforms.Compose([transforms.ToPILImage(),])
In [12]: img2 = trans2(img)
In [13]: type(img2)
Out[13]: PIL.Image.Image
-
Transforms on PIL Image
。包含了诸如图片裁剪,变换等操作。 -
Ttransforms on torch.*Tensor
。线性变换,正则化等操作。 -
Conversion Transforms
。包含了向Tensor或者向PIL Image转换的transforms
。 -
Generic Transforms
。更加通用的转换方法。
DataLoader
即将标准的Dataset装载进DataLoader,从而实现mini batch,shuffle等便捷操作。【时间有限,下次更】
网友评论