美文网首页
pytorch 迁移学习

pytorch 迁移学习

作者: zidea | 来源:发表于2020-08-10 20:55 被阅读0次
pytorch_bannar.png

今天我们案例取材于 pytorch 的官方教程,使用迁移学习来训练神经网络来做图片分类。
Finetuning the convnet: 在此我们不再随机初始化化参数,而是利用某些预训练网络来初始化网络参数,这样我们网络就在基于 imagenet 1000 数据集基础上更进一步训练自己数据集
ConvNet as fixed feature extractor: 我们可以通过冻结出最后全连接层以外的所有其他层,进行训练,替换后最后一层变为随机参数全连接层

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

加载数据

使用 torchvisiontorch.utils.data 包来加载此数据。今天的任务是训练一个可以区分蚂蚁和蜜蜂的模型。对于每一个类别我们各有 120 训练图片和 75 验证图片。选择小数据集的目的也是在想要说明如何通过迁移学习来训练一个非常小数据集。这是 imagenet 图片数据集的一小部分。

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  • data_transforms 将数据转换为标准化[0.485, 0.456, 0.406], [0.229, 0.224, 0.225] ,使用 ToTensor 将图片转换为 tensor,使用 Centercrop 对图片随机裁剪从而达到图片增强的目标。
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

相关文章

  • pytorch 迁移学习

    今天我们案例取材于 pytorch 的官方教程,使用迁移学习来训练神经网络来做图片分类。Finetuning th...

  • [PyTorch]迁移学习

    背景预训练模型使用的训练数据并非训练集,可能来自ImageNet数据库等用于图像分类,像素语义分割,对象检测,实例...

  • 迁移学习pytorch

    以VGG网络为例:1、只调整一层:以后禁止使用这种写法 2、调整整个classifier层:要调整把整个分类层都要...

  • Pytorch 迁移学习

    概念 迁移学习简单来说就是使用别人已经训练好的模型的参数,并根据需求修改模型。比如vgg模型默认是输入一张三通道的...

  • 迁移学习_pytorch简单实战

    迁移学习_pytorch实战 想学习一下迁移学习,则将使用预先训练的网络,来构建用于疟疾检测的图像分类器,这个分类...

  • 2019-04-29 手把手教你用pytorch实现迁移学习(第

    一 前言 1 前提条件: 这是一篇小白实现迁移学习的文章,本文不需要事先掌握Pytorch或者迁移学习的相关知识,...

  • 人人都是毕加索

    基于 Pytorch 和 VGG19 模型实现图片风格迁移。 相关 Pytorch 官方教程 相关 Github ...

  • Pytorch学习记录-使用Pytorch进行深度学习,迁移学习

    迁移学习(Transfer Learning)在完成60分钟入门之后,接下来有六节tutorials和五节关于文本...

  • 数据不足,如何进行迁移学习?

    摘要: 在没有足够的训练数据时,本文详细介绍了如何使用FloydHub、fast.ai和PyTorch进行迁移学习...

  • pytorch学习(十二)—迁移学习Transfer Learn

    前言 在训练深度学习模型时,有时候我们没有海量的训练样本,只有少数的训练样本(比如几百个图片),几百个训练样本显然...

网友评论

      本文标题:pytorch 迁移学习

      本文链接:https://www.haomeiwen.com/subject/kfcsdktx.html