美文网首页
pytorch笔记01-数据增强

pytorch笔记01-数据增强

作者: 坐下等雨 | 来源:发表于2019-08-18 17:16 被阅读0次

    1、什么是数据增强

    数据增强是扩充数据样本规模的一种有效地方法。深度学习是基于大数据的一种方法,我们当前希望数据的规模越大、质量越高越好。模型才能够有着更好的泛化能力,然而实际采集数据的时候,往往很难覆盖掉全部的场景,比如:对于光照条件,在采集图像数据时,我们很难控制光线的比例,因此在训练模型的时候,就需要加入光照变化方面的数据增强。再有一方面就是数据的获取也需要大量的成本,如果能够自动化的生成各种训练数据,就能做到更好的开源节流。

    2、数据增强的作用

    • 增加训练的数据量,提高模型的泛化能力
    • 增加噪声数据,提升模型的鲁棒性

    3、如何进行数据增强

    数据增强可以分为两类,一类是离线增强,一类是在线增强。

    • 离线增强 : 直接对数据集进行处理,数据的数目会变成增强因子 x 原数据集的数目 ,这种方法常常用于数据集很小的时候.
    • 在线增强 : 这种增强的方法用于,获得 batch 数据之后,然后对这个 batch 的数据进行增强,如旋转、平移、翻折等相应的变化,由于有些数据集不能接受线性级别的增长,这种方法长用于大的数据集,很多机器学习框架已经支持了这种数据增强方式,并且可以使用 GPU 优化计算。

    4、pytorch数据增强操作

    pytorch中数据增强的常用方法如下:

    • 对图片进行一定比例的缩放
    • 对图片进行随机的截取
    • 对图片进行随机水平和竖直翻转
    • 对图片进行随机角度的旋转
    • 对图片进行亮度、对比度和颜色的随机变化等

    torchvision中内置的transforms包含了这些些常用的图像变换,这些变换能够用Compose串联组合起来。

    from PIL import Image
    from torchvision import transforms as tfs
    
    img = Image.open('./dog.jpg')
    print('原图:')
    img
    

    原图:


    4.1、中心处裁剪PIL图片

    class torchvision.transforms.CenterCrop(size)

    • size(序列 或 int)– 需要裁剪出的形状。如果size是int,将会裁剪成正方形;如果是形如(h, w)的序列,将会裁剪成矩形。
    print('原图像尺寸:{}'.format(img.size))
    re_img = tfs.CenterCrop(200)(img)
    print('中心裁剪后尺寸:{}'.format(re_img.size))
    re_img
    

    原图像尺寸:(658, 411)
    中心裁剪后尺寸:(200, 200)


    4.2 随机改变图片的亮度、对比度和饱和度

    class torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

    • brightness(float或 float类型元组(min, max))– 亮度的扰动幅度。应当是非负数。
    • contrast(float或 float类型元组(min, max))– 对比度扰动幅度。应当是非负数。
      saturation(float或 float类型元组(min, max))– 饱和度扰动幅度。应当是非负数。
      hue(float或 float类型元组(min, max))– 色相扰动幅度。hue_factor从[-hue, hue]中随机采样产生,其值应当满足0<= hue <= 0.5或-0.5 <= min <= max <= 0.5
    cj_img = tfs.ColorJitter(0.8, 0.8, 0.5)(img)
    cj_img
    

    4.3 图片转换为灰阶

    class torchvision.transforms.Grayscale(num_output_channels=1))

    • num_output_channels(int,1或3)– 希望得到的图片通道数。
    gc_img = tfs.Grayscale(1)(img)
    gc_img
    

    4.4 图像的各条边缘进行扩展

    class torchvision.transforms.Pad(padding, fill=0, padding_mode='constant')

    • padding(int 或 tuple)– 在每条边上展开的宽度。如果传入的是单个int,就在所有边展开。如果传入长为2的元组,则指定左右和上下的展开宽度。如果传入长为4的元组,则依次指定为左、上、右、下的展开宽度。
    • fill(int 或 tuple) – 像素填充值。默认是0。如果指定长度为3的元组,表示分别填充R, G, B通道。这个参数仅在padding_mode是‘constant’时指定有效。
    • padding_mode(str)– 展开类型。应当是‘constant’,‘edge’,‘reflect’或‘symmetric’之一。默认为‘constant’。
      • constant:用常数扩展,这个值由fill参数指定。
      • edge:用图像边缘上的值填充。
      • reflect:以边缘为对称轴进行轴对称填充(边缘值不重复)。
      • symmetric:用图像边缘的反转进行填充(图像的边缘值需要重复)。
    # 用常数0填充
    con_img = tfs.Pad(50, fill=0, padding_mode='constant')(img)
    con_img
    
    # 用图像边缘值填充
    edge_img = tfs.Pad(50, fill=0, padding_mode='edge')(img)
    edge_img
    
    # 以边缘为对称轴进行轴对称填充
    ref_img = tfs.Pad(50, fill=0, padding_mode='reflect')(img)
    ref_img
    

    4.5 图片在随机位置处进行裁剪

    class torchvision.transforms.RandomCrop(size, padding=0, pad_if_needed=False)

    • size(序列 或 int)– 想要裁剪出的图片的形状。如果size是int,按照正方形(size, size)裁剪; 如果size是序列(h, w),裁剪为矩形。
    • padding(int 或 序列 , 可选)– 在图像的边缘进行填充,默认0,即不做填充。如果指定长为4的序列,则分别指定左、上、右、下的填充宽度。
    • pad_if_needed(boolean)– 如果设置为True,若图片小于目标形状,将进行填充以避免报异常。
    rc_img = tfs.RandomCrop(200)(img)
    rc_img
    

    4.6 以给定的概率随机水平翻折PIL图片

    class torchvision.transforms.RandomHorizontalFlip(p=0.5)

    • p(float)– 翻折图片的概率。默认0.5。
    rh_img = tfs.RandomHorizontalFlip(1)(img)
    rh_img
    

    4.7 以给定的概率随机垂直翻折PIL图片

    class torchvision.transforms.RandomVerticalFlip(p=0.5)

    • p(float)– 翻折图片的概率。默认0.5。
    rv_img = tfs.RandomVerticalFlip(1)(img)
    rv_img
    

    4.8 以指定的角度选装图片

    class torchvision.transforms.RandomRotation(degrees, resample=False, expand=False, center=None)

    • degrees(序列 或 float or int)– 旋转角度的随机选取范围。如果degrees是序列(min, max),则从中随机选取;如果是数字,则选择范围是(-degrees, +degrees)。
    • resample({PIL.Image.NEAREST , PIL.Image.BILINEAR , PIL.Image.BICUBIC} , 可选) – 可选的重采样滤波器。如果该选项忽略,或图片模式是“1”或者“P”则设置为PIL.Image.NEAREST。
    • expand(bool, 可选)– 可选的扩展标志。如果设置为True, 将输出扩展到足够大从而能容纳全图。如果设置为False或不设置,输出图片将和输入同样大。注意expand标志要求 flag assumes rotation around the center and no translation。
    • center(2-tuple , 可选)– 可选的旋转中心坐标。以左上角为原点计算。默认是图像中心。
    rr_img = tfs.RandomRotation(45)(img)
    rr_img
    

    以上都是对图像做单次变换,torchvision提供torchvision.transforms.Compose()函数,可以将以上图像方法联合起来使用,比如先做随机翻转,然后随机截取,再做对比度增强等。
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    aug_img = tfs.Compose([
        tfs.Resize(200),
        tfs.RandomHorizontalFlip(),
        tfs.RandomCrop(120),
    tfs.RandomVerticalFlip(),
        tfs.ColorJitter(0.5, 0.5, 0.5)
    ])
    
    _, figs = plt.subplots(3, 3, figsize=(10, 10))
    for i in range(3):
        for j in range(3):
            figs[i][j].imshow(aug_img(img))
            figs[i][j].axes.get_xaxis().set_visible(False)
            figs[i][j].axes.get_yaxis().set_visible(False)
    

    相关文章

      网友评论

          本文标题:pytorch笔记01-数据增强

          本文链接:https://www.haomeiwen.com/subject/przrsctx.html