Input Data


Author: _Cooper_ | Published 2018-07-01 22:06

    Preparing the data

    1. Datasets built into torchvision (MNIST, CIFAR10)

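    import torchvision.transforms as transforms
    from torchvision.datasets import MNIST

    # MNIST digits are 28x28: resize them to 32x32, then convert to float tensors in [0, 1].
    # download=False assumes the files are already under root; use download=True on the first run.
    data_train = MNIST(root='./pytorch_data/mnist', download=False,
                       transform=transforms.Compose([
                           transforms.Resize((32, 32)),
                           transforms.ToTensor()]))
    data_test = MNIST(root='./pytorch_data/mnist', train=False, download=False,
                      transform=transforms.Compose([
                          transforms.Resize((32, 32)),
                          transforms.ToTensor()]))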
    

    PyTorch's image preprocessing package

    Transform functions in PyTorch
    transforms docs

    2. ImageFolder: for ImageNet-style data

    ImageFolder docs

    The images should be arranged like this (one sub-folder per class):
    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png

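    In my experiments TinyImagenet was laid out as class/images/xxx.JPEG. This also works, and the class mapping is still picked up correctly, because ImageFolder searches each class folder recursively.
    ImageFolder automatically maps class names to integer indices; you can inspect the mapping with .class_to_idx.
    In the end I used TinyImagenet, located at C:/Users/Cooper/PycharmProjects/TinyImagenet/train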

    import torchvision.transforms as transforms
    from torchvision.datasets import ImageFolder
    from torch.utils.data import DataLoader

    transform = transforms.Compose([
        # transforms.RandomCrop(227),
        transforms.ToTensor(),
        # standard ImageNet per-channel means and standard deviations
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # each sub-folder of train/ is one class; labels come from the folder names
    trainset = ImageFolder('C:/Users/Cooper/PycharmProjects/TinyImagenet/train', transform=transform)
    trainloader = DataLoader(trainset, batch_size=256, shuffle=True)

    # get some random training images
    for i, (img, label) in enumerate(trainloader):
        pass  # set a breakpoint here to inspect img and label

    print('finish')
    
    

    Step through the for loop (for example with a breakpoint inside it) to see what img and label are.


    [Figure: img and label]

    As expected, img is a tensor holding a batch of images, and label is a tensor holding the corresponding class indices.
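The batch shapes can be made concrete without TinyImagenet on disk. This sketch swaps in a `TensorDataset` of random tensors: 10 hypothetical 3×64×64 "images" (TinyImagenet's image size) with labels drawn from 200 classes (TinyImagenet's class count).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 10 random RGB "images" of 64x64 with labels in [0, 200).
images = torch.rand(10, 3, 64, 64)
labels = torch.randint(0, 200, (10,))
loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=True)

batch_shapes = []
for img, label in loader:
    # img: [batch, channels, height, width]; label: [batch] of int64 class indices
    batch_shapes.append((tuple(img.shape), tuple(label.shape), label.dtype))
    print(img.shape, label.shape, label.dtype)
```

With 10 samples and `batch_size=4` the batches have 4, 4 and 2 samples. Pass `drop_last=True` to `DataLoader` if every batch must be full-sized.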

    3. PIL.Image: reading a single image

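    from PIL import Image
    import torchvision.transforms as transforms

    transform = transforms.Compose([
        # random crop + resize (training-style augmentation); Resize + CenterCrop gives repeatable test-time results
        transforms.RandomResizedCrop(227),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = Image.open('C:/Users/Cooper/PycharmProjects/Imagenet/test/ILSVRC2010_test_00000008.JPEG').convert('RGB')
    image = transform(image)  # tensor of shape [3, 227, 227]
    # add a batch dimension -> [1, 3, 227, 227]; image.unsqueeze(0) is equivalent
    image = image.view(1, image.size(0), image.size(1), image.size(2))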
    

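    Note that Image.open() expects the path to an image file, not to a folder. Passing a folder path fails with [Errno 13] Permission denied on Windows (Linux reports [Errno 21] Is a directory instead).
    Note also that to feed an image through the network (the `net` class), the input must be four-dimensional: batch_size × channels × height × width. A single image read from disk has no batch dimension, so .view() is used to add it as the first dimension.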

    DataLoader

    # reshuffle the training set every epoch; keep the test set in a fixed order
    data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True)
    data_test_loader = DataLoader(data_test, batch_size=256, shuffle=False)
    

    DataLoader docs

    Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.

    In other words, it combines a dataset with a sampler and provides an iterator over the dataset (DataloaderIter).
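This sketch illustrates the "dataset + sampler" split. A hypothetical toy dataset whose i-th sample is just the number i makes the batch order visible. The sampler decides the index order, and the DataLoader groups those indices into batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler, RandomSampler

# Hypothetical toy dataset: sample i is the number i, so batches reveal the order.
dataset = TensorDataset(torch.arange(8))

# SequentialSampler yields indices 0, 1, 2, ...; the DataLoader batches them.
ordered = DataLoader(dataset, batch_size=3, sampler=SequentialSampler(dataset))
print([batch[0].tolist() for batch in ordered])  # [[0, 1, 2], [3, 4, 5], [6, 7]]

# shuffle=True is shorthand for sampler=RandomSampler(dataset): same samples, random order.
shuffled = DataLoader(dataset, batch_size=3, sampler=RandomSampler(dataset))
seen = sorted(x for batch in shuffled for x in batch[0].tolist())
print(seen)                                      # [0, 1, 2, 3, 4, 5, 6, 7]
```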

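    • Iterator: an iterator (sometimes called a cursor) is an interface for traversing a container such as a linked list or an array. Each container type must provide its own iterator to give access to its data.
    dataiter = iter(trainloader)
    images, labels = next(dataiter)  # built-in next(); the dataiter.next() method is no longer available in recent PyTorch releases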
    

    dataiter is now a DataloaderIter that gives access to the DataLoader's data; each call to next() returns the next batch. (Reference)
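Here is a minimal sketch of the iterator protocol on a DataLoader, using hypothetical data (5 samples of 2 features each). `next()` returns one batch at a time, and `StopIteration` is raised once every batch has been consumed, which is how a `for` loop knows when to stop.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 5 samples with 2 features each, plus integer labels.
features = torch.arange(10, dtype=torch.float32).view(5, 2)
labels = torch.tensor([0, 1, 0, 1, 1])
loader = DataLoader(TensorDataset(features, labels), batch_size=2, shuffle=False)

dataiter = iter(loader)
images, targets = next(dataiter)       # first batch
print(images.shape, targets.tolist())  # torch.Size([2, 2]) [0, 1]

remaining = 0
while True:
    try:
        next(dataiter)
        remaining += 1
    except StopIteration:              # raised after the last batch
        break
print(remaining)                       # 2  (batches of size 2 and 1 were left)
```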
