美文网首页
【深度学习DL-PyTorch】七、加载图像数据(Loading

【深度学习DL-PyTorch】七、加载图像数据(Loading

作者: ChiangCMBA | 来源:发表于2019-11-08 21:36 被阅读0次
    import torch
    from torchvision import datasets, transforms
    

    加载图像的最简单方式是使用 torchvisiondatasets.ImageFolder文档)。使用 ImageFolder,就是这样:

    dataset = datasets.ImageFolder('path/to/data', transform=transform)
    

    其中 'path/to/data' 是通往数据目录的文件路径,transform 是用 torchvision 中的 transforms 模块构建的处理步骤列表。ImageFolder 中的文件和目录应按以下格式构建:

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png
    
    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png
    

    每个图像类别都有各自存储图像的目录(catdog)。然后使用从目录名中提取的类别标记图像。图像 123.png 将采用类别标签 cat。你可以从此页面下载已经采用此结构的数据集。在其中已被拆分成了训练集和测试集。

    一、 转换

    使用 ImageFolder 加载数据时,你需要定义转换。例如,图像的尺寸不相同,但是我们需要将它们变成统一尺寸,才能用于训练模型。你可以使用 transforms.Resize() 调整尺寸或使用 transforms.CenterCrop()transforms.RandomResizedCrop() 等裁剪图像。我们还需要使用 transforms.ToTensor() 将图像转换为 PyTorch 张量。通常,你将使用 transforms.Compose()来将这些转换结合到一条流水线中,这条流水线接收包含转换的列表,并按顺序运行。流程大概为缩放、裁剪,然后转换为张量:

    transform = transforms.Compose([transforms.Resize(255),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor()])
    

    二、 数据加载器

    加载 ImageFolder 后,你需要将其传入 DataLoaderDataLoader 接受数据集(例如要从 ImageFolder 获得的数据集),并返回批次图像和相应的标签。你可以设置各种参数,例如批次大小,或者在每个周期之后是否重排数据。

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    

    dataloader 是一个生成器。要从中获取数据,你需要遍历它,或将它转换成迭代器并调用 next()

    # Looping through it, get a batch on each loop 
    for images, labels in dataloader:
        pass
    
    # Get one batch
    images, labels = next(iter(dataloader))
    

    三、 数据增强

    训练神经网络的一个常见策略是在输入数据本身里引入随机性。例如,你可以在训练过程中随机地旋转、翻转、缩放和/或裁剪图像。这样一来,你的神经网络在处理位置、大小、方向不同的相同图像时,可以更好地进行泛化。
    要随机旋转、缩放、裁剪图像,然后翻转图像,你需要如下所示地定义转换:

    train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                           transforms.RandomResizedCrop(224),
                                           transforms.RandomHorizontalFlip(),
                                           transforms.ToTensor(),
                                           transforms.Normalize([0.5, 0.5, 0.5], 
                                                                [0.5, 0.5, 0.5])])
    

    另外,还需要使用 transforms.Normalize 标准化图像。传入均值和标准偏差列表,然后标准化颜色通道。

    减去 mean 使数据以 0 居中,除以 std 使值位于 -1 到 1 之间。标准化有助于神经网络使权重接近 0,这能使反向传播更为稳定。不标准化的话,网络往往会学习失败。

    你可以在此处查看可用的转换列表。测试时,不能改变图像(但是需要以同一方式标准化)。因此,在验证/测试图像时,通常只能调整大小和裁剪图像。

    %matplotlib inline
    %config InlineBackend.figure_format = 'retina'
    
    import matplotlib.pyplot as plt
    
    import torch
    from torchvision import datasets, transforms
    
    data_dir = 'Cat_Dog_data'
    
    # TODO: Define transforms for the training data and testing data
    train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                           transforms.RandomResizedCrop(224),
                                           transforms.RandomHorizontalFlip(),
                                           transforms.ToTensor(),
                                           transforms.Normalize([0.5, 0.5, 0.5],
                                                                [0.5, 0.5, 0.5])])
    
    # test_transforms = transforms.Compose([transforms.RandomRotation(30),
    #                                       transforms.RandomResizedCrop(224),
    #                                       transforms.RandomHorizontalFlip(),
    #                                       transforms.ToTensor(),
    #                                       transforms.Normalize([0.5, 0.5, 0.5],
    #                                                            [0.5, 0.5, 0.5])])
    
    #测试时,不能改变图像(但是需要以同一方式标准化)。因此,在验证/测试图像时,通常只能调整大小和裁#剪图像。
    test_transforms = transforms.Compose([transforms.Resize(255),
                                          transforms.CenterCrop(224),
                                          transforms.ToTensor()])
    
    # Pass transforms in here, then run the next cell to see how the transforms look
    train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
    test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)
    
    trainloader = torch.utils.data.DataLoader(train_data, batch_size=32)
    testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
    
    
    # change this to the trainloader or testloader 
    data_iter = iter(testloader)
    
    images, labels = next(data_iter)
    fig, axes = plt.subplots(figsize=(10,4), ncols=4)
    for ii in range(4):
        ax = axes[ii]
        helper.imshow(images[ii], ax=ax, normalize=False)
    

    相关文章

      网友评论

          本文标题:【深度学习DL-PyTorch】七、加载图像数据(Loading

          本文链接:https://www.haomeiwen.com/subject/chtckctx.html