美文网首页pytorch我爱编程
Pytorch学习(3): 常用工具模块简介

Pytorch学习(3): 常用工具模块简介

作者: 月牙眼的楼下小黑 | 来源:发表于2018-04-04 22:16 被阅读64次

作 者: 月牙眼的楼下小黑
联 系zlf111@mail.ustc.edu.cn
声 明: 欢迎转载本文中的图片或文字,请说明出处


参考资料:

[1].PyTorch常用工具模块

1 数据处理

import torch 
from torch.utils import data
import os 
from PIL import Image
import numpy as np 
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()

1.1 数据加载

Pytorch 中, 数据加载可通过自定义一个继承Dataset类的数据集对象, 并实现两个方法:

  • __ getitem __ : 返回一条数据
  • __ len __ : 返回样本数量
class DogCat(data.Dataset):
    def __init__(self, root):
        imgs = os.listdir(root)                               # root:图片所在文件夹路径
        self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        if 'dog' in img_path.split('/')[-1]:
            label =1
        else: 
            label = 0
        pil_img= Image.open(img_path)            # 利用 python 图像处理标准库的 open 方法打开图片
        array = np.asarray(pil_img)              # 将 PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]
        #array = np.transpose(array, (2, 0, 1))  # 将 channel last 形式转化成  channel first 形式:[channel, height, width]
        data = torch.from_numpy(array)           # 将 np.ndarray  转化为 Tensor 形式
        return data, label
    
    def __len__(self):
        return len(self.imgs)

补充: 对于三维矩阵的转置, 如 a.transpose(2,0,1), 意思是原矩阵a(aix 0, aix 1, aix 2) 处的值,现在成为了转置后矩阵 (aix 2, aix 0 , aix 1)处的值。

in:
dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs')
in:
# 显示第一张图片
img,label= dataset[0]
plt.imshow(img)                    # 若为 channel last 形式的 tensor, 可用 matplotlib 中 imshow() 方法
print(label, img.size(), img.float().mean())
out:
1 torch.Size([500, 282, 3]) 169.23073522458628
in:
# 显示第一张图片
img,label= dataset[0]
plt.imshow(img)                    # 若为 channel first 形式的 tensor, 可用 transforms 中 的 ToPILImage() 方法
print(label, img.size(), img.float().mean())

在前面文章中提到过:ToPILImage 可以将

  • shape(C,H,W)Tensor
  • shape(H,W,C)numpy.ndarray

转化成PIL.Image,值不变,方便可视化。注意到 它只能转变channel first形式的Tensor 。而在上面的__getitem__中,array = np.asarray(pil_img)PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]。 所以如果我们要使用 ToPILImage 方法显示图片,在将 PIL.image 转化为 np. ndarray 形式后,还需要利用转置方法将 channel last 形式改成 channel first 形式:[channel, height, width]

1.2 数据预处理

torchvision.transforms模块提供了对 PILImage对象和Tensor对象的常用操作。

PILImage的操作包括:

  • Scale: 调整图片尺寸,长宽比保持不变
  • CenterCropRandomCropRandomSizeCrop:裁剪图片
  • Pad
  • ToTensor: 将PILImage对象转化成channel firstTensor并归一至[0,1]

Tensor的操作包括:

  • Normalize: 标准化,减均值,除以标准差
  • ToPlLImage:将Tensor转化为PILImage对象
in:
trans = transforms.Resize((100,100))
image = Image.open('./dog.1.jpg')
print(image.size)
image = trans(image)
print(image.size)
out:
(327, 499)
(100, 100)

如果要对图片进行多个操作, 可通过Compose方法将这些操作拼接起来。

in:
transform = transforms.Compose([
    transforms.Resize(224),           # 缩放图片,保持长宽比不变,最短边为224像素
    transforms.CenterCrop(224),   # 从图片中间切出 224x224 的图片
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [1, 1, 1])
])

class DogCat(data.Dataset):
    def __init__(self, root, transforms = None):
        imgs = os.listdir(root)                                                # root:图片所在文件夹路径
        self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
        self.transforms = transforms
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        if 'dog' in img_path.split('/')[-1]:
            label =1
        else: 
            label = 0
        data = Image.open(img_path)           # 利用 python 图像处理标准库的 open 方法打开图片
        if self.transforms:
            data = self.transforms(data)
        return data, label
    
    def __len__(self):
        return len(self.imgs)
in:
dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs', transforms = transform)
img,label= dataset[0]
print(img.size())
show(img)
out:
torch.Size([3, 224, 224])

1.3 ImageFolder

torchvision预先实现了常用的DataSet,如CIFAR-10, 可通过 torchvision.datasets.CIFAR10来调用。这里介绍一个经常使用的 DataSet——ImageFolder. ImageFolder 假设所有文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名, ImageFolder会根据文件夹名顺序自动生成 label, 可以通过 class_to_idx查看 label 和 文件夹名的映射关系。其构造函数如下:

ImageFolder(root, transform = None, target_transform = None, Loader = default_loader)
  • root :文件夹路径
  • transform: 对PILImage的转化操作, transform的输入是loader的返回对象
  • target_transform: 对 label 的转化
  • loader : 读取图片函数,默认为读取RGB格式的 PILImage对象
in:
from torchvision.datasets import ImageFolder
dataset = ImageFolder('./myDataSet/dog_cat_data/train')
in:
dataset.class_to_idx
out:
{'cats': 0, 'dogs': 1}
in:
# 此时还没有任何 transform, 返回的是 PILImage 对象
# 第一维指示第几张图片,第二维为 1 返回 label, 为 0 返回 图片数据
print(dataset[0][1]) 
dataset[0][0]
out:
0

1.4 DataLoader

调用DataSet中的__getitem__只返回一个样本,而我们需要batch wise trainingPytorch提供了 DataLoader帮助我们实现这些功能。其构造函数如下:

DataLoader(dataset, 
           batch_size=1, 
           shuffle = False,
           sample =None,
           sampler = None,
           num_workers =0,
           collate_fn = default_collate, 
           pin_memory =False,
           drop_last = False)
  • dataset: 加载的数据集对象(DataSet对象)
  • shuffle: 是否将数据打乱
  • sampler: 样本抽样
  • num_workers: 使用多进程加载的进程数
  • collate_fn: 如何将多个样本数据拼接成一个patch
  • pin_memory: 是否将数据保存在pin memory, pin memory中的数据转到gpu会快一些
  • drop_last: 当datast中的数据个数不是batchsize的整数倍, drop_lastTrue会将多出来不足一个batch的数据丢弃。

2. torchvision

torchvisionPytorch 团队开发的独立于 Pytorch的视觉工具包,通过pip install torchvision安装,主要包含三部分:

  • models:提供一系列经典已经预训练好的模型,包括AlexNet, VGG, ResNet, Inception
  • datasets: 提供常用的数据集对象(DataSet对象), 包括MNIST, CIFAR10/100,ImageNet, COCO
  • transforms: 提供常用数据预处理工作,主要包括对 TensorPILImage对象的操作。

3. 可视化工具 Visdom

Visdom可以创造、组织和共享多种数据的可视化,包括数值、图像、文本、视频, 支持 PytorchNumpy。.

  • Visdom的安装: pip install visdom
  • Visdom的启动: python -m visdom.server, 打开浏览器输入:http://localhost:8097, 8097 是默认端口号。

Visdom中的两个重要概念:

  • env: 环境。不同用户、不同程序一般使用不同 env. 不同 env 相互独立,互不影响。使用时不指定 env ,则默认使用main
  • pane: 窗格。 一个env中可以有多个不同的pane, 每个pane可视化或记录某一信息,可以拖动、缩放、保存或关闭
In:
import visdom
vis = visdom.Visdom(env=u'test1')    # 构建一个客户端对象,创建一个名为' test1' 的 env 
x = torch.arange(1, 30, 0.01)
y = torch.sin(x)
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})   #  win 是 pane 名字,opts 设置 pane 格式,如 title, xlabel,ylabel

vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'}) 中, win 参数指定 pane 名字, 如果不指定,visdom将自动分配一个新的pane. 如果两次操做指定的win名字一样,新操作将覆盖当前 pane 的内容。如在 上面的 pane中画 y = x 函数,原来的 y = sin(x) 将被覆盖。
In:
y = x
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=x'})

如果不想覆盖原图,可以使用updateTrace方法,如:
y = x + 1
vis.updateTrace(X=x, Y=y, win='sinx', name='this is a new Trace')

未完待续

相关文章

网友评论

  • IT人故事会:做开发很累,还的学习,之前你这个我也碰到过,但是没记录谢谢了
    月牙眼的楼下小黑:@IT人故事会 不不不,你还是很厉害的,视野广阔,啥都碰过,而且记忆不错,两次评论一模一样,严谨精准,跟机器人一样

本文标题:Pytorch学习(3): 常用工具模块简介

本文链接:https://www.haomeiwen.com/subject/wvblhftx.html