美文网首页
pytorch 函数DataLoader

pytorch 函数DataLoader

作者: 顾北向南 | 来源:发表于2019-10-04 13:08 被阅读0次

    参考链接:https://www.e-learn.cn/content/qita/850153
    本文仅作为学术分享,如果侵权,会删文处理

    • Dataset只负责数据的抽象,一次调用getitem只返回一个样本。前面提到过,在训练神经网络时,最好是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader帮助我们实现这些功能。
    • DataLoader的函数定义如下:
    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
    
    • dataset:加载的数据集(Dataset对象)
    • batch_size:batch size
    • shuffle::是否将数据打乱
    • sampler: 样本抽样,后续会详细介绍
    • num_workers:使用多进程加载的进程数,0代表不使用多进程
    • collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
    • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
    • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
    from torchvision.datasets import ImageFolder
    import torch
    from torchvision import transforms
    from torch.utils.data import DataLoader
    
    #加上transforms
    normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
    transform=transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
        normalize
    ])
    
    dataset=ImageFolder('data/dogcat_2/',transform=transform)
    
    #dataloader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它 或者 or batch_datas, batch_labels in dataloader:
    dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
    
    dataiter = iter(dataloader)
    imgs, labels = next(dataiter)
    print(imgs.size()) # batch_size, channel, height, weight
    #输出 torch.Size([3, 3, 224, 224])
    
    • 在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在_ getitem _函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。
    '''
    在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。
    '''
    from dataSet import *
    import random
    class NewDogCat(DogCat): # 继承前面实现的DogCat数据集
        def __getitem__(self, index):
            try:
                # 调用父类的获取函数,即 DogCat.__getitem__(self, index)
                return super(NewDogCat,self).__getitem__(index)
            except:
                #对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:
                new_index = random.randint(0, len(self) - 1)
                return self[new_index]
    
    from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
    from torch.utils.data import DataLoader
    def my_collate_fn(batch):
        '''
        batch中每个元素形如(data, label)
        '''
        # 过滤为None的数据
        batch = list(filter(lambda x:x[0] is not None, batch))
        if len(batch) == 0: return torch.Tensor()
        return default_collate(batch) # 用默认方式拼接过滤后的batch数据
    
    
    transform=transforms.Compose([
        transforms.Resize(224), #缩放图片,保持长宽比不变,最短边的长为224像素,
        transforms.CenterCrop(224), #从中间切出 224*224的图片
        transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
        transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至[-1,1]
    ])
    
    
    dataset = NewDogCat(root='data/dogcat_wrong/', transform=transform)
    
    #print(dataSet[11])
    dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1,shuffle=True)
    for batch_datas, batch_labels in dataloader:
        print(batch_datas.size(),batch_labels.size())
    

    相关文章

      网友评论

          本文标题:pytorch 函数DataLoader

          本文链接:https://www.haomeiwen.com/subject/hmfcyctx.html