transforms模块详解
transforms是torchvision中的一个重要模块,它是Pytorch的图像预处理包,包含了很多种对图像数据进行变换的函数,这些都是我们加载训练数据步骤中必不可少的。比较常见的是下面的这部分代码:
data_transforms = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
Compose方法是将多种变换组合在一起。上述对data_transforms进行了四种变换,前两个是对PILImage进行的,分别对其进行随机大小和随机宽高比的裁剪,之后resize到指定大小224,以及对原始图像进行随机的水平翻转;
第三个transforms.ToTensor()将PILImage的转变为torch.FloatTensor的数据形式;最后一个Normalize则是对tensor进行的,不要问这些数值是怎么来的它们都是从ImageNet训练模型中总结出来的参数。下面需要着重强调一点是多种组合变换有一定的先后顺序,处理PILImage的变换方法(大多数方法)都需要放在ToTensor方法之前,而处理tensor的方法如上面的Normalize方法则要放在ToTensor方法之后。
transforms中的一些函数
- ToTensor:convert a PIL image to tensor (HWC) in range [0,255] to a torch.Tensor(CHW) in the range [0.0,1.0]
- Normalize:Normalized an tensor image with mean and standard deviation
即:用给定的均值和标准差分别对每个通道的数据进行正则化。具体来说,给定均值(M1,M2,.....,Mn),给定标准差(S1,S2,...,Sn),其中n是通道数(一般是3),对每个通道进行如下的操作:
output[channel]=(input[channel]-mean[channel])/std[channel]
例如:原来的tensor是三个维度,值在[0,1]之间,经过变换之后得到[-1,1]
计算如下:
((0,1)-0.5)/0.5=(-1,1)
- ToPILImage:convert a tensor to PILImage
transforms针对PILImage的操作还有很多
- 1.CenterCrop:在图片的中间区域进行裁剪
- 2.RandomCrop:在一个随机的位置进行裁剪
- 3.RandomHorizonFlip:以0.5的概率水平翻转给定的PIL图像
- 4.RandomVerticalFlip:以0.5的概率垂直翻转给定的PIL图像
- 5.RandomResizedFlip:将PIL图像裁剪成任意大小纵横比
- 6.Grayscale:将图像转换为灰度图像
- 7.RandomGrayscale:将图像以一定的概率转换为灰度图像
- 8.FiceCrop:把图像裁剪为四个角和一个中心
- 9.Pad:填充
- 10.ColorJitter:随机改变图像的亮度对比度和饱和度
网友评论