美文网首页pytorch
pytorch划分数据集的方法

pytorch划分数据集的方法

作者: 午字横 | 来源:发表于2022-11-25 20:24 被阅读0次

    1:torch.utils.data.random_split()划分数据集
    torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

    随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。
    **参数**
    dataset (Dataset) – 要划分的数据集。
    lengths (sequence) – 要划分的长度。
    generator (Generator) – 用于随机排列的生成器。

    import torch
    import torchvision
    # from torch.utils.tensorboard import SummaryWriter
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    # 准备数据集
    from torch import nn
    from torch.utils.data import DataLoader
    
    # 定义训练的设备
    device = torch.device("cuda")
    
    #读取数据
    data_transform = transforms.Compose([
        transforms.Resize(size=(224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5])
    ])
    full_dataset = ImageFolder(r'D:\PythonSpace\data\trainTest',transform = data_transform)
    
    # length 数据集总长度
    full_data_size = len(full_dataset)
    print("总数据集的长度为:{}".format(full_data_size))
    
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    train_data_size = len(train_dataset)
    test_data_size = len(test_dataset)
    # 如果train_data_size=10, 训练数据集的长度为:10
    print("训练数据集的长度为:{}".format(train_data_size))
    print("测试数据集的长度为:{}".format(test_data_size))
    
    
    >>>
    总数据集的长度为:244
    训练数据集的长度为:195
    测试数据集的长度为:49
    

    相关文章

      网友评论

        本文标题:pytorch划分数据集的方法

        本文链接:https://www.haomeiwen.com/subject/wvbzxdtx.html