美文网首页
Pytorch笔记2-pytorch的数据集构建和相关操作

Pytorch笔记2-pytorch的数据集构建和相关操作

作者: 江湾青年 | 来源:发表于2024-07-10 11:27 被阅读0次

    torch.utils.data模块提供了一些用于数据加载和处理的工具,其中最常用的类和函数包括 Dataset, DataLoader, Sampler 以及相关的辅助工具。这些工具使得处理大型数据集以及在批处理、并行化和数据预处理等方面变得更加简便。

    Dataset

    • Dataset是一个抽象类,用户可以通过继承它来定义自己的数据集。需要实现 __len____getitem__ 方法。

    • 抽象类是一种不能被实例化的类,它通常作为其他类的基类,提供抽象方法的定义,而这些方法需要在具体的子类中实现。抽象类的主要作用是定义接口或提供框架,确保子类实现特定的方法,从而保证子类具有一致的接口和行为。

    • 示例代码

    from torch.utils.data import Dataset
    
    class MyDataset(Dataset):
        def __init__(self, data, labels):
            self.data = data
            self.labels = labels
    
        def __len__(self):
            return len(self.data)
    
        def __getitem__(self, idx):
            return self.data[idx], self.labels[idx]
    
    # 创建数据集
    dataset = MyDataset(data, labels)
    dataset.data
    dataset.labels
    

    random_split

    • random_split用于将数据集按比例随机划分成多个子集。

    • 示例代码

    from torch.utils.data import random_split
    
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    

    DataLoader

    • DataLoader 是用于将数据集分成小批量,并提供自动化多线程数据加载的工具。常用参数包括 batch_size, shuffle, num_workers 等。

    • 示例代码

    from torch.utils.data import DataLoader
    
    dataset = MyDataset(data, labels)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    

    总结

    先定义Dataset类创建数据集,然后random_split划分数据集,最后DataLoader常见train_loader/valid_loader/test_loader

    相关文章

      网友评论

          本文标题:Pytorch笔记2-pytorch的数据集构建和相关操作

          本文链接:https://www.haomeiwen.com/subject/nxezcjtx.html