1:torch.utils.data.random_split()划分数据集
torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)
随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。
**参数**
dataset (Dataset) – 要划分的数据集。
lengths (sequence) – 要划分的长度。
generator (Generator) – 用于随机排列的生成器。
import torch
import torchvision
# from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
# 准备数据集
from torch import nn
from torch.utils.data import DataLoader
# 定义训练的设备
device = torch.device("cuda")
#读取数据
data_transform = transforms.Compose([
transforms.Resize(size=(224,224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5])
])
full_dataset = ImageFolder(r'D:\PythonSpace\data\trainTest',transform = data_transform)
# length 数据集总长度
full_data_size = len(full_dataset)
print("总数据集的长度为:{}".format(full_data_size))
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
train_data_size = len(train_dataset)
test_data_size = len(test_dataset)
# 如果train_data_size=10, 训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
>>>
总数据集的长度为:244
训练数据集的长度为:195
测试数据集的长度为:49
网友评论