在构建深度学习的过程中我们话大量的时间处理数据。一般我们会用到opencv,PIL,skimage等图像处理库。我们今天主要介绍torchvision图像处理库。
torch 中的数据
一般来说torch中的数据必须封装成torch.utils.data.Dataset
。
from torch.utils.data import Dataset
from os.path import join, exists, basename
from glob import glob
from torchvision import transforms as T
class DogCat(Dataset):
def __init__(self, folder, transform = T.ToTensor()):
self.images = glob(join(folder, '*.jpg'))
self.transform = transform
def __getitem__(self, item):
img = Image.open(self.images[item])
data = self.transform(img)
target = 1 if basename(self.images[item])).startswith('dog') else 0
return data, target
def __len__(self):
return len(self.images)
其中,transform
参数用来传递图像处理功能。一般我们用torchvision.transforms
来处理图像T.ToTensor()
表示把PIL
格式的图像处理成pytorch
是识别的数据。
这里需要注意的一点是,ToTensor()
会把图像从WxHxC
变成CxWxH
,同时默认把数据按照一定mean
和 std
进行归一化。
我们还可以对PIL
图片做更多的操作,如RandomCrop
、Affine
(放射变化)等, 同时还支持把这些操作按照一定的顺序进行组合,如下所示:
transform = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize([.5, .5, .5], [.5, .5, .5]) #这里表示mean=[R, G, B], std=[R, G, B]
])
-
T.Resize(416)(img)
表示图片的缩放(或者增大),把图像的最小边缩放为416。 -
T.CenterCrop(416)(img)
表示把图片从中剪裁,裁剪后的大小为416x416
网友评论