美文网首页
Pytorch学习笔记(9) 通过DataSet、Dataset

Pytorch学习笔记(9) 通过DataSet、Dataset

作者: 银色尘埃010 | 来源:发表于2020-06-17 20:43 被阅读0次

    如何将我们准备好的数据放入模型中呢? Pytorch 给出的答案都在torch.utils.data 包中。

    一、先看看所有的类

    这个模块中方法并不多,所以让我们先全部列出来看看,看看名字猜猜功能。

    • Class torch.utils.data.Dataset 一个抽象类, 所有其他类的数据集类都应该是它的子类。所有子类应该重载lengetitem,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
    • Class torch.utils.data.DataLoader 数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
    • Class torch.utils.data.IterableDataset

    • Class torch.utils.data.TensorDataset

    • Class torch.utils.data.ConcatDataset(datasets)

    • Class torch.utils.data.ChainDataset(datasets)

    • Class torch.utils.data.Subset(dataset, indices)
      以上就是所有的类,之后的内容主要介绍Dataset和DatasetLoader这两个类,因为学会了这两个类,以后你可以按照任何你想的方式向模型中输入数据了。
      除了以上的CLASS,torch.utils.data 包中还提供了一些的数据采样的类和方法。相信大家以前都应该用过sklearn的train_test_split(),以下的其中一个方法也提供了类似的功能。

    • torch.utils.data.random_split(dataset, lengths) 按照给定的长度将数据集划分成没有重叠的新数据集组合。

    • CLASStorch.utils.data.Sampler(data_source)

    • CLASStorch.utils.data.SequentialSampler(data_source)

    • CLASStorch.utils.data.RandomSampler(...)

    • torch.utils.data.SubsetRandomSampler(...)

    • CLASStorch.utils.data.WeightedRandomSampler(...)

    • CLASStorch.utils.data.BatchSampler(sampler, batch_size, drop_last)

    • CLASStorch.utils.data.distributed.DistributedSampler(...)

    二、Dataset和DatasetLoader

    一般情况下,使用Dataset和DatasetLoader两个类已经可以完成大部分的数据导入。首先来看Dataset类。
    在此对象中,必须重写以下两个方法。

    def __getitem__(self, index)
          return  index对应的一条数据,可以是一张图,可以是一句话,总之 记住,一条数据。
         
    def  __len__():
        return  带训练数据的总长度, 如果是dataframe, 直接len(df)即可,或者在init的时候传入了长度,直接返回
    

    接下来看DataLoader 类

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
    

    关键的几个参数:

    • dataset: 就是第一个介绍的Dataset, 实例化之后传入这里
    • batch_size: 这个不多说了
    • shuffle: 对于train_data, 一般选择true; 其他一般选择false
    • sampler: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.

    看看实例:
    想到sklearn中提供了一些小数据集,使用鸢尾花(iris)的数据集:

    def loaddata():
        iris_data = datasets.load_iris()
        return iris_data["data"], iris_data["target"]
    
    class IrisDataset(Dataset):
        def __init__(self,irisdata,target):
            #   传入参数
            #   ndarray 类型的,可以是任何类型的
            self.irisdata = irisdata
            self.target = target
            self.lens = len(irisdata)
    
        def __getitem__(self, index):
            # index是方法自带的参数,获取相应的第index条数据
            return self.irisdata[index,:],self.target[index]
    
        def __len__(self):
            return self.lens
    

    数据集就构架完成了,大家也可以通过DataFrame来处理数据。
    然后结合DataLoader使用:

    data,target = loaddata()
    dataset_iris = IrisDataset(data,target)
    train_loader = torch.utils.data.DataLoader(dataset_iris, batch_size=10,   shuffle=True, num_workers=4)
    
    for i, (input, target) in enumerate(tqdm.tqdm(train_loader)):
            print(input.size())
            # 在这之后就可以进行训练了
    
    输出

    三、random_split 介绍

    pytorch 中 random_split可以将实现sklearn 的 train_test_split类似的功能,大家可能注意到了,在上面的例子中只有训练数据,一般还需要有test set和valid set。
    那么我们用random_split来划分数据集吧:

        data,target = loaddata()
        dataset_iris = IrisDataset(data,target)
    
        all_length = len(dataset_iris)
        train_size = int(0.80 * all_length)
        test_size = all_length - train_size
    
        train_dataset,test_dataset = torch.utils.data.random_split(dataset_iris,[train_size,test_size])
    
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=4)
    
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=4)
    

    到这里就已经分好了,不过还是建议先通过其他方法提前分好。为了使每次结果都相同,可以设置好seed。

    相关文章

      网友评论

          本文标题:Pytorch学习笔记(9) 通过DataSet、Dataset

          本文链接:https://www.haomeiwen.com/subject/ushexktx.html