torchvision指南
本笔记引用自PyTorch中文文档
1. torchvision介绍
torchvision
包含了目前流行的数据集,模型结构和常用的图片转换工具。
1.1 torchvision.datasets
-
torchvision.datasets
中包含以下数据集:MNIST
COCO
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10
-
dataset
拥有以下API:__getitem__
,__len__
- 这些
datasets
都是torch.utils.data.Dataset
的子类
1.1.1 MNIST
-
dest.MNIST(root, train=True, transform=None, target_transform=None, download=False)
-download: 是否联网下载
1.1.2 COCO
- 图像标注
dest.CocoCaptions(root="dir where images are", annFile="json annotation file", [transform, target_transform, transform])
# sample
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
annFile = 'json annotation file',
transform=transforms.ToTensor())
print(Number of sample: ', len(cap))
img, target = cap[3]
print('Image size:', img.size())
print(target)
'''
output
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
'''
- 检测
- `dset.CocoDetection(root='dir where images are', annFile='json annotation file', [transform, target_transform])
1.1.3 ImageFolder
- 一个通用的数据加载器,数据集中数据以以下方式组织:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
-
dset.ImageFolder(root='root folder path', [transform, target_transform])
-
self.classes
: 用一个list保存类名 -
self.class_to_idx
: 类名对应的索引 -
self.imgs
: 保存(img_path, class)tuple的list
-
1.1.4 CIFAR
dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
1.2 torchvision.models
-
torchvision.models
模块的子模块中包含以下模型结构:AlexNet
VGG
ResNet
SqueezeNet
DenseNet
- 可使用预训练模型,如:
torchvision.models.alexnet(pretrained=False, **kwargs)
torchvision.models.resnet18(pretrained=False, **kwargs)
torchvision.models.resnet34(pretrained=False, ** kwargs)
torchvision.models.resnet50(pretrained=False, ** kwargs)
torchvision.models.resnet101(pretrained=False, ** kwargs)
torchvision.models.resnet152(pretrained=False, ** kwargs)
torchvision.models.vgg11(pretrained=False, ** kwargs)
torchvision.models.vgg11_bn(** kwargs)
torchvision.models.vgg13(pretrained=False, ** kwargs)
torchvision.models.vgg13_bn(** kwargs)
torchvision.models.vgg16(pretrained=False, ** kwargs)
torchvision.models.vgg16_bn(** kwargs)
torchvision.models.vgg19(pretrained=False, ** kwargs)
torchvision.models.vgg19_bn(** kwargs)
1.3 torchvision.transforms
- 对
PIL.Image
进行变换 - 使用
torchvision.transforms.Compose(transforms)
将多个transform
组合起来使用 -
transforms.CenterCrop(size)
:将给定PIL.Image
进行中心切割,得到给定的size
,size
可以是tuple
或Integer
-
transforms.RandomCrop(size, padding=0)
:切割中心点位置随机选取,size
可以是tuple
或Integer
-
transforms.RandomHorizontalFlip(p=0.5)
:随机水平翻转 -
transforms.RandomSizedCrop(size, interpolation=2)
:先随机切,再resize
成给定size
大小。 -
transforms.Pad(padding, fill=0)
:给所有边用给定的值填充。padding
:要填充多少像素 -
transforms.ToTensor()
:将一个取值范围是[0, 255]
的PIL.Image
或shape
为(H, W, C)
的numpy.ndarray
,转换成形状为[C, H, W]
,取值范围是[0, 1.0]
的torch.FloatTensor
-
transforms.Normalize(mean, std)
:给定均值与方差,正则化,即Normalized_image=(image-mean)/std
- 通用变换:使用
lambda
作为转换器,transforms.Lambda(lambda)
1.4 torchvision.utils
-
utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
给定4D-mini-batch Tensor
,形状为(B*C*H*W)
,或者一个a list of image
,做成一个size
为(B / nrow, nrow)
的子图集- normalize=True, 对图像像素归一化
- range=(min, max),min和max是数字,则min, max用来规范化
image
- scale_each=True, 每个图片独立规范化。
-
utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
将给定的Tensor
保存成image文件,如果是mini-batch tensor
,就用make-grid
做成子图集再保存。
网友评论