如何将我们准备好的数据放入模型中呢? Pytorch 给出的答案都在torch.utils.data 包中。
一、先看看所有的类
这个模块中方法并不多,所以让我们先全部列出来看看,看看名字猜猜功能。
- Class torch.utils.data.Dataset 一个抽象类, 所有其他类的数据集类都应该是它的子类。所有子类应该重载len和getitem,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
- Class torch.utils.data.DataLoader 数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
-
Class torch.utils.data.IterableDataset
-
Class torch.utils.data.TensorDataset
-
Class torch.utils.data.ConcatDataset(datasets)
-
Class torch.utils.data.ChainDataset(datasets)
-
Class torch.utils.data.Subset(dataset, indices)
以上就是所有的类,之后的内容主要介绍Dataset和DatasetLoader这两个类,因为学会了这两个类,以后你可以按照任何你想的方式向模型中输入数据了。
除了以上的CLASS,torch.utils.data 包中还提供了一些的数据采样的类和方法。相信大家以前都应该用过sklearn的train_test_split(),以下的其中一个方法也提供了类似的功能。 -
torch.utils.data.random_split(dataset, lengths) 按照给定的长度将数据集划分成没有重叠的新数据集组合。
-
CLASStorch.utils.data.Sampler(data_source)
-
CLASStorch.utils.data.SequentialSampler(data_source)
-
CLASStorch.utils.data.RandomSampler(...)
-
torch.utils.data.SubsetRandomSampler(...)
-
CLASStorch.utils.data.WeightedRandomSampler(...)
-
CLASStorch.utils.data.BatchSampler(sampler, batch_size, drop_last)
-
CLASStorch.utils.data.distributed.DistributedSampler(...)
二、Dataset和DatasetLoader
一般情况下,使用Dataset和DatasetLoader两个类已经可以完成大部分的数据导入。首先来看Dataset类。
在此对象中,必须重写以下两个方法。
def __getitem__(self, index)
return index对应的一条数据,可以是一张图,可以是一句话,总之 记住,一条数据。
def __len__():
return 带训练数据的总长度, 如果是dataframe, 直接len(df)即可,或者在init的时候传入了长度,直接返回
接下来看DataLoader 类
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
关键的几个参数:
- dataset: 就是第一个介绍的Dataset, 实例化之后传入这里
- batch_size: 这个不多说了
- shuffle: 对于train_data, 一般选择true; 其他一般选择false
- sampler: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.
看看实例:
想到sklearn中提供了一些小数据集,使用鸢尾花(iris)的数据集:
def loaddata():
iris_data = datasets.load_iris()
return iris_data["data"], iris_data["target"]
class IrisDataset(Dataset):
def __init__(self,irisdata,target):
# 传入参数
# ndarray 类型的,可以是任何类型的
self.irisdata = irisdata
self.target = target
self.lens = len(irisdata)
def __getitem__(self, index):
# index是方法自带的参数,获取相应的第index条数据
return self.irisdata[index,:],self.target[index]
def __len__(self):
return self.lens
数据集就构架完成了,大家也可以通过DataFrame来处理数据。
然后结合DataLoader使用:
data,target = loaddata()
dataset_iris = IrisDataset(data,target)
train_loader = torch.utils.data.DataLoader(dataset_iris, batch_size=10, shuffle=True, num_workers=4)
for i, (input, target) in enumerate(tqdm.tqdm(train_loader)):
print(input.size())
# 在这之后就可以进行训练了
输出
三、random_split 介绍
pytorch 中 random_split可以将实现sklearn 的 train_test_split类似的功能,大家可能注意到了,在上面的例子中只有训练数据,一般还需要有test set和valid set。
那么我们用random_split来划分数据集吧:
data,target = loaddata()
dataset_iris = IrisDataset(data,target)
all_length = len(dataset_iris)
train_size = int(0.80 * all_length)
test_size = all_length - train_size
train_dataset,test_dataset = torch.utils.data.random_split(dataset_iris,[train_size,test_size])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=4)
到这里就已经分好了,不过还是建议先通过其他方法提前分好。为了使每次结果都相同,可以设置好seed。
网友评论