美文网首页
数据增强

数据增强

作者: 小黄不头秃 | 来源:发表于2022-09-12 00:29 被阅读0次

(一)数据增强(增广)

(1)为什么要做数据增强?

一个原因是可能你的数据集比较小,所以需要对数据进行简单的操作,让数据集增加。第二是有这样的一个真实实例。有一家做智能售货机的公司在公司内部调试好参数训练好模型以后,将售货机拿去展厅进行测试的时候,发现原本准确率非常高的机器忽然识别不出来了。原因是展厅的光源不一样,导致整个的测试数据集就和训练集发生了很大的变化。所以在产品研发的时候适当的通过数据增强技术能够给模型增加鲁棒性。

(2)数据增强方法有什么?

可通过在图片中加入各种不一样的背景噪音,改变图片的颜色和形状。

  • 翻转:左右翻转、上下翻转
  • 切割:在图片中切割出一块,然后变形到固定的形状
  • 颜色:改变色调,饱和度,明度等

(二)代码实现

%matplotlib inline
import torch 
import torchvision
from torch import nn 
from d2l import torch as d2l
import matplotlib.image as img
import matplotlib.pyplot as plt

# 打开图片的方法
# image = img.imread('../img/cat1.jpg')
# plt.title("cat.jpg")
# plt.axis("off")
# plt.imshow(image)
# plt.show()

d2l.set_figsize()
img = d2l.Image.open('../img/cat1.jpg')
d2l.plt.imshow(img)
# 参数列表(图片,增强的办法,多少行,多少列,倍数)
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    y = [aug(img) for _ in range(num_rows*num_cols)]
    d2l.show_images(y, num_rows, num_cols, scale=scale)

# 左右翻转图片
apply(img,torchvision.transforms.RandomHorizontalFlip())

# 上下翻转
apply(img,torchvision.transforms.RandomVerticalFlip())

# 随即裁剪
shape_aug = torchvision.transforms.RandomResizedCrop(
    # (输出大小,选择的比例,高宽比)
    size=(200,200),scale=(0.1,1),ratio=(0.5,2))
apply(img,shape_aug)
# 随机更改图片的亮度
apply(img,torchvision.transforms.ColorJitter(
    # (亮度区间,对比度,饱和度,色调)
    brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5
))
# 结合多种数据增强方法
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    shape_aug,
    torchvision.transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5)
])

apply(img, augs)
# 如果下载报错的话,自己去网页上下载
# https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
all_images = torchvision.datasets.CIFAR10(
    train=True,
    root="../data",
    download=False
)

d2l.show_images([all_images[i][0] for i in range(32)],4,8,scale=0.8)
# d2l.show_images([all_images.data[i] for i in range(32)],4,8,scale=0.8)
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()
])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
def load_cifar10(is_train,augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(
        train=is_train,
        root="../data",
        download=False,
        transform=augs,
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )
    return data_loader
def train_batch_ch13(net, X, y, loss, trainer, devices):
    """用多GPU进行小批量训练"""
    if isinstance(X, list):
        # 微调BERT中所需(稍后讨论)
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred, y)
    return train_loss_sum, train_acc_sum
#@save
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """用多GPU进行模型训练"""
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # 4个维度:储存训练损失,训练准确度,实例数,特点数
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(
                net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')
batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
train_with_data_aug(train_augs, test_augs, net)

相关文章

  • 猫狗大战-是否加数据增强(augmentation)的区别

    没有带数据增强 加了数据增强之后

  • 数据增强

    《Batch Normalization: Accelerating Deep Network Training ...

  • 数据增强

  • 数据增强

    (一)数据增强(增广) (1)为什么要做数据增强? 一个原因是可能你的数据集比较小,所以需要对数据进行简单的操作,...

  • 深度学习中的数据增强与实现

    1. 数据增强Data Augmentation 数据增强让有限的数据产生更多的数据,增加训练样本的数量以及多样性...

  • 深度学习数据增强概括二

    接之前深度学习数据增强概括一 基于DL的数据增强 1、Feature space augmentation特征空间...

  • 数据色彩增强

  • NLP 数据增强

    在机器学习领域,个人觉得有一个大前提:数据是永远不够的。虽然现在有很多吹嘘大数据,在自然语言处理领域,标注数据尤其...

  • 图像数据增强

    一、做随机亮度、对比度、饱和度修改,使用tensorflow API核心部分是 aug_op 函数,这可是菜鸟的心血啊!

  • 使用数据增强

    发生了过拟合。训练精度随着时间线性增加,直到接近 100%,而验证精度则停留在 70%~72%。验证损失仅在 5 ...

网友评论

      本文标题:数据增强

      本文链接:https://www.haomeiwen.com/subject/gmybnrtx.html