美文网首页pytorch我爱编程
Pytorch学习(3): 常用工具模块简介

Pytorch学习(3): 常用工具模块简介

作者: 月牙眼的楼下小黑 | 来源:发表于2018-04-04 22:16 被阅读64次

    作 者: 月牙眼的楼下小黑
    联 系zlf111@mail.ustc.edu.cn
    声 明: 欢迎转载本文中的图片或文字,请说明出处


    参考资料:

    [1].PyTorch常用工具模块

    1 数据处理

    import torch 
    from torch.utils import data
    import os 
    from PIL import Image
    import numpy as np 
    import matplotlib.pyplot as plt
    import torchvision.transforms as transforms
    from torchvision.transforms import ToPILImage
    show = ToPILImage()
    

    1.1 数据加载

    Pytorch 中, 数据加载可通过自定义一个继承Dataset类的数据集对象, 并实现两个方法:

    • __ getitem __ : 返回一条数据
    • __ len __ : 返回样本数量
    class DogCat(data.Dataset):
        def __init__(self, root):
            imgs = os.listdir(root)                               # root:图片所在文件夹路径
            self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
            
        def __getitem__(self, index):
            img_path = self.imgs[index]
            if 'dog' in img_path.split('/')[-1]:
                label =1
            else: 
                label = 0
            pil_img= Image.open(img_path)            # 利用 python 图像处理标准库的 open 方法打开图片
            array = np.asarray(pil_img)              # 将 PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]
            #array = np.transpose(array, (2, 0, 1))  # 将 channel last 形式转化成  channel first 形式:[channel, height, width]
            data = torch.from_numpy(array)           # 将 np.ndarray  转化为 Tensor 形式
            return data, label
        
        def __len__(self):
            return len(self.imgs)
    

    补充: 对于三维矩阵的转置, 如 a.transpose(2,0,1), 意思是原矩阵a(aix 0, aix 1, aix 2) 处的值,现在成为了转置后矩阵 (aix 2, aix 0 , aix 1)处的值。

    in:
    dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs')
    
    in:
    # 显示第一张图片
    img,label= dataset[0]
    plt.imshow(img)                    # 若为 channel last 形式的 tensor, 可用 matplotlib 中 imshow() 方法
    print(label, img.size(), img.float().mean())
    
    out:
    1 torch.Size([500, 282, 3]) 169.23073522458628
    
    in:
    # 显示第一张图片
    img,label= dataset[0]
    plt.imshow(img)                    # 若为 channel first 形式的 tensor, 可用 transforms 中 的 ToPILImage() 方法
    print(label, img.size(), img.float().mean())
    

    在前面文章中提到过:ToPILImage 可以将

    • shape(C,H,W)Tensor
    • shape(H,W,C)numpy.ndarray

    转化成PIL.Image,值不变,方便可视化。注意到 它只能转变channel first形式的Tensor 。而在上面的__getitem__中,array = np.asarray(pil_img)PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]。 所以如果我们要使用 ToPILImage 方法显示图片,在将 PIL.image 转化为 np. ndarray 形式后,还需要利用转置方法将 channel last 形式改成 channel first 形式:[channel, height, width]

    1.2 数据预处理

    torchvision.transforms模块提供了对 PILImage对象和Tensor对象的常用操作。

    PILImage的操作包括:

    • Scale: 调整图片尺寸,长宽比保持不变
    • CenterCropRandomCropRandomSizeCrop:裁剪图片
    • Pad
    • ToTensor: 将PILImage对象转化成channel firstTensor并归一至[0,1]

    Tensor的操作包括:

    • Normalize: 标准化,减均值,除以标准差
    • ToPlLImage:将Tensor转化为PILImage对象
    in:
    trans = transforms.Resize((100,100))
    image = Image.open('./dog.1.jpg')
    print(image.size)
    image = trans(image)
    print(image.size)
    
    out:
    (327, 499)
    (100, 100)
    

    如果要对图片进行多个操作, 可通过Compose方法将这些操作拼接起来。

    in:
    transform = transforms.Compose([
        transforms.Resize(224),           # 缩放图片,保持长宽比不变,最短边为224像素
        transforms.CenterCrop(224),   # 从图片中间切出 224x224 的图片
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [1, 1, 1])
    ])
    
    class DogCat(data.Dataset):
        def __init__(self, root, transforms = None):
            imgs = os.listdir(root)                                                # root:图片所在文件夹路径
            self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
            self.transforms = transforms
            
        def __getitem__(self, index):
            img_path = self.imgs[index]
            if 'dog' in img_path.split('/')[-1]:
                label =1
            else: 
                label = 0
            data = Image.open(img_path)           # 利用 python 图像处理标准库的 open 方法打开图片
            if self.transforms:
                data = self.transforms(data)
            return data, label
        
        def __len__(self):
            return len(self.imgs)
    
    in:
    dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs', transforms = transform)
    img,label= dataset[0]
    print(img.size())
    show(img)
    
    out:
    torch.Size([3, 224, 224])
    

    1.3 ImageFolder

    torchvision预先实现了常用的DataSet,如CIFAR-10, 可通过 torchvision.datasets.CIFAR10来调用。这里介绍一个经常使用的 DataSet——ImageFolder. ImageFolder 假设所有文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名, ImageFolder会根据文件夹名顺序自动生成 label, 可以通过 class_to_idx查看 label 和 文件夹名的映射关系。其构造函数如下:

    ImageFolder(root, transform = None, target_transform = None, Loader = default_loader)
    
    • root :文件夹路径
    • transform: 对PILImage的转化操作, transform的输入是loader的返回对象
    • target_transform: 对 label 的转化
    • loader : 读取图片函数,默认为读取RGB格式的 PILImage对象
    in:
    from torchvision.datasets import ImageFolder
    dataset = ImageFolder('./myDataSet/dog_cat_data/train')
    
    in:
    dataset.class_to_idx
    
    out:
    {'cats': 0, 'dogs': 1}
    
    in:
    # 此时还没有任何 transform, 返回的是 PILImage 对象
    # 第一维指示第几张图片,第二维为 1 返回 label, 为 0 返回 图片数据
    print(dataset[0][1]) 
    dataset[0][0]
    
    out:
    0
    

    1.4 DataLoader

    调用DataSet中的__getitem__只返回一个样本,而我们需要batch wise trainingPytorch提供了 DataLoader帮助我们实现这些功能。其构造函数如下:

    DataLoader(dataset, 
               batch_size=1, 
               shuffle = False,
               sample =None,
               sampler = None,
               num_workers =0,
               collate_fn = default_collate, 
               pin_memory =False,
               drop_last = False)
    
    • dataset: 加载的数据集对象(DataSet对象)
    • shuffle: 是否将数据打乱
    • sampler: 样本抽样
    • num_workers: 使用多进程加载的进程数
    • collate_fn: 如何将多个样本数据拼接成一个patch
    • pin_memory: 是否将数据保存在pin memory, pin memory中的数据转到gpu会快一些
    • drop_last: 当datast中的数据个数不是batchsize的整数倍, drop_lastTrue会将多出来不足一个batch的数据丢弃。

    2. torchvision

    torchvisionPytorch 团队开发的独立于 Pytorch的视觉工具包,通过pip install torchvision安装,主要包含三部分:

    • models:提供一系列经典已经预训练好的模型,包括AlexNet, VGG, ResNet, Inception
    • datasets: 提供常用的数据集对象(DataSet对象), 包括MNIST, CIFAR10/100,ImageNet, COCO
    • transforms: 提供常用数据预处理工作,主要包括对 TensorPILImage对象的操作。

    3. 可视化工具 Visdom

    Visdom可以创造、组织和共享多种数据的可视化,包括数值、图像、文本、视频, 支持 PytorchNumpy。.

    • Visdom的安装: pip install visdom
    • Visdom的启动: python -m visdom.server, 打开浏览器输入:http://localhost:8097, 8097 是默认端口号。

    Visdom中的两个重要概念:

    • env: 环境。不同用户、不同程序一般使用不同 env. 不同 env 相互独立,互不影响。使用时不指定 env ,则默认使用main
    • pane: 窗格。 一个env中可以有多个不同的pane, 每个pane可视化或记录某一信息,可以拖动、缩放、保存或关闭
    In:
    import visdom
    vis = visdom.Visdom(env=u'test1')    # 构建一个客户端对象,创建一个名为' test1' 的 env 
    x = torch.arange(1, 30, 0.01)
    y = torch.sin(x)
    vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})   #  win 是 pane 名字,opts 设置 pane 格式,如 title, xlabel,ylabel
    

    vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'}) 中, win 参数指定 pane 名字, 如果不指定,visdom将自动分配一个新的pane. 如果两次操做指定的win名字一样,新操作将覆盖当前 pane 的内容。如在 上面的 pane中画 y = x 函数,原来的 y = sin(x) 将被覆盖。
    In:
    y = x
    vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=x'})
    

    如果不想覆盖原图,可以使用updateTrace方法,如:
    y = x + 1
    vis.updateTrace(X=x, Y=y, win='sinx', name='this is a new Trace')
    

    未完待续

    相关文章

      网友评论

      • IT人故事会:做开发很累,还的学习,之前你这个我也碰到过,但是没记录谢谢了
        月牙眼的楼下小黑:@IT人故事会 不不不,你还是很厉害的,视野广阔,啥都碰过,而且记忆不错,两次评论一模一样,严谨精准,跟机器人一样

      本文标题:Pytorch学习(3): 常用工具模块简介

      本文链接:https://www.haomeiwen.com/subject/wvblhftx.html