准备数据
1. torchvision内置的数据包(MNIST)(CIFAR10)
data_train = MNIST(root='./pytorch_data/mnist', download=False,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]
))
data_test = MNIST(root='./pytorch_data/mnist', train=False, download=False,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]))
pytorch的图像预处理包
pytorch中transform函数
transform Docs
2. ImageFolder-适用于ImageNet数据
the images should be arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
or
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
实验时TinyImagenet的结构为class/images/xxx.JPEG,这样也好用,正确收到class对应关系。
ImageFolder自动将classes对应成int型的index,可以通过.class_to_idx
查看
最终我的数据为TinyImagenet,位置为C:/Users/Cooper/PycharmProjects/TinyImagenet/train
。
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.models import alexnet
from torch.utils.data import DataLoader
transform = transforms.Compose([
# transforms.RandomCrop(227),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
trainset = ImageFolder('C:/Users/Cooper/PycharmProjects/TinyImagenet/train', transform=transform)
trainloader = DataLoader(trainset, batch_size=256, shuffle=True)
# get some random training images
for i, (img, label) in enumerate(trainloader):
pass
print('finish')
通过for循环来看img和label是什么
img和label
可以看到img是图片tensor,label是相应类别的tensor,符合要求
3. PIL.Image - 单独读取图片
transform = transforms.Compose([
transforms.RandomResizedCrop(227),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open('C:/Users/Cooper/PycharmProjects/Imagenet/test/ILSVRC2010_test_00000008.JPEG').convert('RGB')
image = transform(image)
image = image.view(1, image.size(0), image.size(1), image.size(2))
注意Image.open()
中放的是文件而不是路径,如果放路径将得到error[Errno 13] Permission denied
注意图片要通过class net
,需要有四维batch_size
channels
size_len
size_width
,读入的单张图片没有batch_size
,所以用.view()
函数生成了第一维
Dataloader
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True)
data_test_loader = DataLoader(data_test, batch_size=256, shuffle=False)
Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.
将数据集和迭代器组合,并且提供迭代器DataloaderIter
- 迭代器:迭代器(iterator)有时又称游标(cursor),可在容器(container,例如链表或阵列)上遍访的接口,每一种容器型都必须提供自己的迭代器,用来访问容器的数据。
dataiter = iter(trainloader)
images, labels = dataiter.next()
这样dataiter就是一个DataloderIter
,可以用它来访问dataloader的数据。可用next()
函数来访问容器的下一条数据。参考
网友评论