美文网首页
pytorch 迁移学习

pytorch 迁移学习

作者: zidea | 来源:发表于2020-08-10 20:55 被阅读0次
    pytorch_bannar.png

    今天我们案例取材于 pytorch 的官方教程,使用迁移学习来训练神经网络来做图片分类。
    Finetuning the convnet: 在此我们不再随机初始化化参数,而是利用某些预训练网络来初始化网络参数,这样我们网络就在基于 imagenet 1000 数据集基础上更进一步训练自己数据集
    ConvNet as fixed feature extractor: 我们可以通过冻结出最后全连接层以外的所有其他层,进行训练,替换后最后一层变为随机参数全连接层

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    import numpy as np
    import torchvision
    from torchvision import datasets, models, transforms
    import matplotlib.pyplot as plt
    import time
    import os
    import copy
    
    plt.ion()   # interactive mode
    

    加载数据

    使用 torchvisiontorch.utils.data 包来加载此数据。今天的任务是训练一个可以区分蚂蚁和蜜蜂的模型。对于每一个类别我们各有 120 训练图片和 75 验证图片。选择小数据集的目的也是在想要说明如何通过迁移学习来训练一个非常小数据集。这是 imagenet 图片数据集的一小部分。

    # Data augmentation and normalization for training
    # Just normalization for validation
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    
    data_dir = 'data/hymenoptera_data'
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                                 shuffle=True, num_workers=4)
                  for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    • data_transforms 将数据转换为标准化[0.485, 0.456, 0.406], [0.229, 0.224, 0.225] ,使用 ToTensor 将图片转换为 tensor,使用 Centercrop 对图片随机裁剪从而达到图片增强的目标。
    def imshow(inp, title=None):
        """Imshow for Tensor."""
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        plt.imshow(inp)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated
    
    
    # Get a batch of training data
    inputs, classes = next(iter(dataloaders['train']))
    
    # Make a grid from batch
    out = torchvision.utils.make_grid(inputs)
    
    imshow(out, title=[class_names[x] for x in classes])
    

    相关文章

      网友评论

          本文标题:pytorch 迁移学习

          本文链接:https://www.haomeiwen.com/subject/kfcsdktx.html