Input Data

Author: _Cooper_ | Published 2018-07-01 22:06

Preparing the data

1. Built-in torchvision datasets (MNIST, CIFAR10)

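import torchvision.transforms as transforms
from torchvision.datasets import MNIST

# download=False assumes the data is already in ./pytorch_data/mnist;
# use download=True on the first run to fetch it.
data_train = MNIST(root='./pytorch_data/mnist', train=True, download=False,
                   transform=transforms.Compose([
                       transforms.Resize((32, 32)),  # 28x28 digits -> 32x32
                       transforms.ToTensor(),        # PIL image -> float tensor [1, 32, 32] in [0, 1]
                   ]))
data_test = MNIST(root='./pytorch_data/mnist', train=False, download=False,
                  transform=transforms.Compose([
                      transforms.Resize((32, 32)),
                      transforms.ToTensor(),
                  ]))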

PyTorch's image preprocessing package

Transform functions in PyTorch
transform Docs

2. ImageFolder: for ImageNet-style data

doc on ImageFolder

The images should be arranged in this way (one subfolder of root per class):
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

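In my experiments, TinyImagenet is laid out as class/images/xxx.JPEG, with an extra images/ level inside each class folder. This works too: every image is still assigned to the correct class.
ImageFolder automatically maps each class name to an integer index; the mapping can be inspected via .class_to_idx.
The dataset I ended up using is TinyImagenet, located at C:/Users/Cooper/PycharmProjects/TinyImagenet/train.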
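The class-to-index convention can be illustrated with a stdlib-only sketch: sorted subfolder names of root become classes numbered 0..n-1, and files are collected recursively inside each class folder, which is why the extra images/ level in TinyImagenet does not matter. The helpers `build_class_to_idx` and `list_samples` and the extension tuple are hypothetical and only mimic that convention; this is not torchvision's own code.

```python
import os
import tempfile

IMG_EXTS = ('.jpg', '.jpeg', '.png')  # assumption: a subset of accepted extensions


def build_class_to_idx(root):
    """Hypothetical helper: sorted subfolder names -> 0..n-1."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    return {name: i for i, name in enumerate(classes)}


def list_samples(root, class_to_idx):
    """Hypothetical helper: (path, class_index) pairs, walking each class
    folder recursively so class/images/xxx.JPEG is found too."""
    samples = []
    for name, idx in sorted(class_to_idx.items()):
        for dirpath, _, filenames in sorted(os.walk(os.path.join(root, name))):
            for fname in sorted(filenames):
                if fname.lower().endswith(IMG_EXTS):  # case-insensitive: .JPEG matches
                    samples.append((os.path.join(dirpath, fname), idx))
    return samples


with tempfile.TemporaryDirectory() as root:
    # Empty placeholder files in a mixed layout (flat and class/images/)
    for rel in ['n01443537/images/a.JPEG', 'n01443537/images/b.JPEG',
                'cat/123.png', 'dog/xxx.png']:
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, 'wb').close()

    c2i = build_class_to_idx(root)
    samples = list_samples(root, c2i)

print(c2i)                        # {'cat': 0, 'dog': 1, 'n01443537': 2}
print([i for _, i in samples])    # [0, 1, 2, 2]
```

Because the indices come from sorting folder names, the same class always gets the same index regardless of the order in which the filesystem lists the folders.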

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

transform = transforms.Compose([
    # transforms.RandomCrop(227),
    transforms.ToTensor(),
    # Per-channel ImageNet means and standard deviations
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Each subfolder of train/ is treated as one class
trainset = ImageFolder('C:/Users/Cooper/PycharmProjects/TinyImagenet/train', transform=transform)
trainloader = DataLoader(trainset, batch_size=256, shuffle=True)

# Fetch one shuffled batch and look at what it contains
for i, (img, label) in enumerate(trainloader):
    print(i, img.shape, label.shape)
    break

print('finish')

The for loop is used to inspect what img and label are.


[Figure: img and label]

As can be seen, img is a tensor of images and label is a tensor of the corresponding class indices, which is what we want.
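The shapes of those two tensors can be checked without the real dataset. This sketch uses a hypothetical stand-in: a `TensorDataset` of 10 random 3x64x64 "images" (the Tiny ImageNet image size) with random integer labels. ImageFolder returns plain Python int labels, which the DataLoader's default collate function turns into an int64 tensor, so the label batch looks the same.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for trainset: random data, not real images
images = torch.rand(10, 3, 64, 64)
labels = torch.randint(0, 200, (10,))  # int64 class indices
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_shapes = []
for img, label in loader:
    # img: float tensor [batch, 3, 64, 64]; label: int64 tensor [batch]
    batch_shapes.append((tuple(img.shape), tuple(label.shape), label.dtype))

for s in batch_shapes:
    print(s)
```

With 10 samples and batch_size=4, the last batch holds only the remaining 2 samples; pass drop_last=True to DataLoader to discard it instead.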

3. PIL.Image: reading a single image

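from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    # Random crop + resize (augmentation); for deterministic evaluation,
    # Resize followed by CenterCrop is the more usual choice
    transforms.RandomResizedCrop(227),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Path to an image file (not a directory); convert('RGB') guarantees 3 channels
image = Image.open('C:/Users/Cooper/PycharmProjects/Imagenet/test/ILSVRC2010_test_00000008.JPEG').convert('RGB')
image = transform(image)  # tensor of shape [3, 227, 227]
# Add a leading batch dimension: [3, 227, 227] -> [1, 3, 227, 227]
image = image.view(1, image.size(0), image.size(1), image.size(2))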

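Note that Image.open() expects the path to an image file, not a directory. Passing a directory raises an error: on Windows it is [Errno 13] Permission denied (on Linux, IsADirectoryError [Errno 21]).
Note also that to pass an image through the network (the net class), the input must be four-dimensional: (batch_size, channels, height, width). A single image read this way has no batch dimension, so .view() is used to add one as the first dimension.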
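The same batch dimension can be added in several equivalent ways. This sketch uses a random [3, 227, 227] tensor as a stand-in for the transformed image and compares `.view()` with `unsqueeze(0)` and `None` indexing; `torch.stack` is shown for building a batch from several images.

```python
import torch

# Stand-in for the transformed image: random values, shape [3, 227, 227]
image = torch.rand(3, 227, 227)

a = image.view(1, image.size(0), image.size(1), image.size(2))  # as in the post
b = image.unsqueeze(0)   # insert a size-1 dimension at position 0
c = image[None]          # indexing with None also inserts a new axis

print(a.shape)                                # torch.Size([1, 3, 227, 227])
print(torch.equal(a, b), torch.equal(a, c))   # True True

# Several images of the same size: stack them along a new first dimension
batch = torch.stack([image, image])
print(batch.shape)                            # torch.Size([2, 3, 227, 227])
```

unsqueeze(0) does not need the sizes spelled out, so it keeps working if the crop size in the transform changes.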

Dataloader

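from torch.utils.data import DataLoader

# Shuffle the training set every epoch; keep the test set in a fixed order
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True)
data_test_loader = DataLoader(data_test, batch_size=256, shuffle=False)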

Dataloader Docs

Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.

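In other words, it combines a dataset with a sampler (which decides the order of the sample indices) and provides an iterator over the data, a DataLoader iterator object (DataloaderIter).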
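The sampler's role can be made visible directly. By default, DataLoader uses a RandomSampler when shuffle=True and a SequentialSampler when shuffle=False, and groups the indices with a BatchSampler. This sketch calls those samplers on a toy list of 10 items; the seeded generator is only there to make the random order reproducible.

```python
import torch
from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

data = list(range(10))  # any sized dataset works; samplers only produce indices

# shuffle=False: indices in order, grouped into batches of 4
seq_batches = list(BatchSampler(SequentialSampler(data), batch_size=4, drop_last=False))
print(seq_batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# shuffle=True: a random permutation of the indices
g = torch.Generator().manual_seed(0)
rand_batches = list(BatchSampler(RandomSampler(data, generator=g), batch_size=4, drop_last=False))
flat = [i for batch in rand_batches for i in batch]
print(sorted(flat) == data)  # True: every index appears exactly once
```

The DataLoader then fetches dataset[i] for each index in a batch and collates the results into tensors.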

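  • Iterator: an iterator, sometimes also called a cursor, is an interface for traversing a container (such as a linked list or an array). Each container type must provide its own iterator for accessing its data.
dataiter = iter(trainloader)
images, labels = next(dataiter)  # built-in next(); newer PyTorch versions have no .next() method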

dataiter is then a DataLoader iterator that can be used to access the DataLoader's data: each call to next() returns the next item, which here is the next batch. Reference
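The iterator protocol behind iter() and next() can be sketched in plain Python. `MiniLoader` and `_MiniLoaderIter` are hypothetical classes, far simpler than PyTorch's: the loader is an iterable that hands out a fresh iterator object on each iter() call, and that iterator returns one batch per next() call and raises StopIteration at the end of the pass.

```python
class MiniLoader:
    """Hypothetical DataLoader-like iterable: yields consecutive slices of a
    list as batches (no shuffling, no worker processes)."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        # A new iterator object per call, so each pass starts from the beginning
        return _MiniLoaderIter(self)


class _MiniLoaderIter:
    def __init__(self, loader):
        self.loader = loader
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.loader.data):
            raise StopIteration  # end of one pass (epoch)
        end = self.pos + self.loader.batch_size
        batch = self.loader.data[self.pos:end]
        self.pos = end
        return batch


loader = MiniLoader(list(range(5)), batch_size=2)
it = iter(loader)
print(next(it))        # [0, 1]
print(next(it))        # [2, 3]
print(next(it))        # [4]
print(next(it, None))  # None (iterator exhausted)
print(list(loader))    # [[0, 1], [2, 3], [4]]
```

A for loop over the loader does exactly this: it calls iter() once and then next() until StopIteration.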


      Article link: https://www.haomeiwen.com/subject/bwbhuftx.html