美文网首页
数据增强

数据增强

作者: 小黄不头秃 | 来源:发表于2022-09-12 00:29 被阅读0次

    (一)数据增强(增广)

    (1)为什么要做数据增强?

    一个原因是可能你的数据集比较小,所以需要对数据进行简单的操作,让数据集增加。第二是有这样的一个真实实例。有一家做智能售货机的公司在公司内部调试好参数训练好模型以后,将售货机拿去展厅进行测试的时候,发现原本准确率非常高的机器忽然识别不出来了。原因是展厅的光源不一样,导致整个的测试数据集就和训练集发生了很大的变化。所以在产品研发的时候适当的通过数据增强技术能够给模型增加鲁棒性。

    (2)数据增强方法有什么?

    可通过在图片中加入各种不一样的背景噪音,改变图片的颜色和形状。

    • 翻转:左右翻转、上下翻转
    • 切割:在图片中切割出一块,然后变形到固定的形状
    • 颜色:改变色调,饱和度,明度等

    (二)代码实现

    %matplotlib inline
    import torch 
    import torchvision
    from torch import nn 
    from d2l import torch as d2l
    import matplotlib.image as img
    import matplotlib.pyplot as plt
    
    # 打开图片的方法
    # image = img.imread('../img/cat1.jpg')
    # plt.title("cat.jpg")
    # plt.axis("off")
    # plt.imshow(image)
    # plt.show()
    
    d2l.set_figsize()
    img = d2l.Image.open('../img/cat1.jpg')
    d2l.plt.imshow(img)
    
    # 参数列表(图片,增强的办法,多少行,多少列,倍数)
    def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
        y = [aug(img) for _ in range(num_rows*num_cols)]
        d2l.show_images(y, num_rows, num_cols, scale=scale)
    
    # 左右翻转图片
    apply(img,torchvision.transforms.RandomHorizontalFlip())
    
    # 上下翻转
    apply(img,torchvision.transforms.RandomVerticalFlip())
    
    # 随即裁剪
    shape_aug = torchvision.transforms.RandomResizedCrop(
        # (输出大小,选择的比例,高宽比)
        size=(200,200),scale=(0.1,1),ratio=(0.5,2))
    apply(img,shape_aug)
    
    # 随机更改图片的亮度
    apply(img,torchvision.transforms.ColorJitter(
        # (亮度区间,对比度,饱和度,色调)
        brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5
    ))
    
    # 结合多种数据增强方法
    augs = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        shape_aug,
        torchvision.transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5)
    ])
    
    apply(img, augs)
    
    # 如果下载报错的话,自己去网页上下载
    # https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
    all_images = torchvision.datasets.CIFAR10(
        train=True,
        root="../data",
        download=False
    )
    
    d2l.show_images([all_images[i][0] for i in range(32)],4,8,scale=0.8)
    # d2l.show_images([all_images.data[i] for i in range(32)],4,8,scale=0.8)
    
    train_augs = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor()
    ])
    test_augs = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    
    def load_cifar10(is_train,augs, batch_size):
        dataset = torchvision.datasets.CIFAR10(
            train=is_train,
            root="../data",
            download=False,
            transform=augs,
        )
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0
        )
        return data_loader
    
    def train_batch_ch13(net, X, y, loss, trainer, devices):
        """用多GPU进行小批量训练"""
        if isinstance(X, list):
            # 微调BERT中所需(稍后讨论)
            X = [x.to(devices[0]) for x in X]
        else:
            X = X.to(devices[0])
        y = y.to(devices[0])
        net.train()
        trainer.zero_grad()
        pred = net(X)
        l = loss(pred, y)
        l.sum().backward()
        trainer.step()
        train_loss_sum = l.sum()
        train_acc_sum = d2l.accuracy(pred, y)
        return train_loss_sum, train_acc_sum
    
    #@save
    def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices=d2l.try_all_gpus()):
        """用多GPU进行模型训练"""
        timer, num_batches = d2l.Timer(), len(train_iter)
        animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                                legend=['train loss', 'train acc', 'test acc'])
        net = nn.DataParallel(net, device_ids=devices).to(devices[0])
        for epoch in range(num_epochs):
            # 4个维度:储存训练损失,训练准确度,实例数,特点数
            metric = d2l.Accumulator(4)
            for i, (features, labels) in enumerate(train_iter):
                timer.start()
                l, acc = train_batch_ch13(
                    net, features, labels, loss, trainer, devices)
                metric.add(l, acc, labels.shape[0], labels.numel())
                timer.stop()
                if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                    animator.add(epoch + (i + 1) / num_batches,
                                 (metric[0] / metric[2], metric[1] / metric[3],
                                  None))
            test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
            animator.add(epoch + 1, (None, None, test_acc))
        print(f'loss {metric[0] / metric[2]:.3f}, train acc '
              f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
        print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
              f'{str(devices)}')
    
    batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)
    
    def init_weights(m):
        if type(m) in [nn.Linear, nn.Conv2d]:
            nn.init.xavier_uniform_(m.weight)
    
    net.apply(init_weights)
    
    def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
        train_iter = load_cifar10(True, train_augs, batch_size)
        test_iter = load_cifar10(False, test_augs, batch_size)
        loss = nn.CrossEntropyLoss(reduction="none")
        trainer = torch.optim.Adam(net.parameters(), lr=lr)
        train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
    
    train_with_data_aug(train_augs, test_augs, net)
    

    相关文章

      网友评论

          本文标题:数据增强

          本文链接:https://www.haomeiwen.com/subject/gmybnrtx.html