torch.utils.data
模块提供了一些用于数据加载和处理的工具,其中最常用的类和函数包括 Dataset, DataLoader, Sampler 以及相关的辅助工具。这些工具使得处理大型数据集以及在批处理、并行化和数据预处理等方面变得更加简便。
Dataset
-
Dataset
是一个抽象类,用户可以通过继承它来定义自己的数据集。需要实现__len__
和__getitem__
方法。 -
抽象类是一种不能被实例化的类,它通常作为其他类的基类,提供抽象方法的定义,而这些方法需要在具体的子类中实现。抽象类的主要作用是定义接口或提供框架,确保子类实现特定的方法,从而保证子类具有一致的接口和行为。
-
示例代码
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 创建数据集
dataset = MyDataset(data, labels)
dataset.data
dataset.labels
random_split
-
random_split
用于将数据集按比例随机划分成多个子集。 -
示例代码
from torch.utils.data import random_split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
DataLoader
-
DataLoader
是用于将数据集分成小批量,并提供自动化多线程数据加载的工具。常用参数包括 batch_size, shuffle, num_workers 等。 -
示例代码
from torch.utils.data import DataLoader
dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
总结
先定义Dataset类创建数据集,然后random_split划分数据集,最后DataLoader常见train_loader/valid_loader/test_loader
网友评论