美文网首页
pytorch数据集相关操作

pytorch数据集相关操作

作者: random_walk | 来源:发表于2018-12-14 15:47 被阅读0次

常用数据集读取

对于常用的数据集,可以通过torchvision.datasets读取,torchvision.datasetstorch.utils.data.Dataset的一个实现,因此由_len___getitem__。常见的数据集如COCO,MNIST,CIFAR都可以在这里面找到,全部支持数据集见链接
API一般如下所示,以MNIST为例,其余只需要修改名字即可,略有不同:

class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

参数选项:

  • root:存放数据集的位置
  • train:训练集还是测试集
  • transform:数据预处理
  • target_transform:标注的预处理
  • download:是否下载,如果已经在root存在,可以不下载
    比如以CIFAR10为例:
cifar = torchvision.datasets.CIFAR10(root = "../data/", train= True, download = True)

torch.utils.data.Dataset

Pytorch中数据集被抽象为一个抽象类torch.utils.data.Dataset,所有的数据集都应该继承这个类,且重写以下两个方法:

  • __len__:数据集的长度
  • __getitem__:返回一条数据或一个样本。表示怎么读取数据。
    比如
    def __getitem__(self, index):
        img_path, label = self.data[index].img_path, self.data[index].label
        img = Image.open(img_path)

        return img, label

此外还可以添加预处理操作。

torch.utils.data.DataLoader

数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。API定义如下:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

有以下常用参数:

  • dataset: 需要加载的数据集
  • batch_size :batch大小
  • shuffle :是否打乱顺序
  • num_workers:多进程
    全部参数见链接

torchvision.transforms

我们可以用transforms进行以下常见操作:

  • PIL.Image/numpy.ndarray与Tensor的相互转化;
  • 归一化;
  • PIL.Image进行裁剪、缩放等操作。
    class torchvision.transforms.Compose(transforms):输入一个 transform 列表,将多个 transform 组合使用。
    因此一般如下使用:
transforms.Compose([
                    transforms.CenterCrop(10),#将给定的 PIL.Image 进行中心切割,得到给定的 size,size 可以是 tuple,(target_height, target_width)。size 也可以是一个 Integer,在这种情况下,切出来的图片形状是正方形。
                    transforms.Scale(12),
                    transforms.RandomCrop(size, padding=0)#切割中心点的位置随机选取
                    transfroms.ToTensor()#把一个取值范围是 [0, 255] 的 PIL.Image 或者 shape 为 (H, W, C) 的 numpy.ndarray,转换为形状为 (C, H, W),取值范围是 [0, 1] 的 torch.FloatTensor
                    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))#规范化到[-1,1]
])

全部操作见链接

数据集定义和加载流程

定义预处理:transform = transforms.Compose()
定义数据集:dataset = myDataSet()#自己定义或者直接从torchvision中加载
加载数据集:data = DataLoader(dataset)

相关文章

网友评论

      本文标题:pytorch数据集相关操作

      本文链接:https://www.haomeiwen.com/subject/vbcxqqtx.html