美文网首页
图像分类

图像分类

作者: 潘旭 | 来源:发表于2020-04-17 18:27 被阅读0次

李沐 《动手学深度学习》 第三章

from typing import Iterable

import torch
import torchvision
def load_data_fashion_mnist(batch_size, resize=None, root='./FashionMNIST'):
    """Download the fashion mnist dataset and then load into memory."""
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())

    transform = torchvision.transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)

    num_workers = 4

    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_iter, test_iter

读取数据

batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

定义和初始化模型

num_inputs = 784
num_output = 10

class FlattenLayer(torch.nn.Module):
    """
    将 x 打平
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        将 x 打平, x.shape: (B, 28, 28) 打平后变成 x.shape: (B, 28*28)
        :param x:
        :return:
        """
        return x.view(x.shape[0], -1)


class Net(torch.nn.Module):
    """
    线性模型
    """

    def __init__(self):
        super().__init__()
        self.flatten = FlattenLayer()
        self.linear = torch.nn.Linear(in_features=num_inputs,
                                      out_features=num_output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        进行运算
        :param x:
        :return:
        """
        x = self.flatten(x)
        y = self.linear(x)
        return y


net = Net()

print("net ...")

初始化参数

torch.nn.init.normal_(net.linear.weight, mean=0, std=0.01)
torch.nn.init.constant_(net.linear.bias, val=0)

也可以通过 name paramter 来初始化

for name, param in net.named_parameters():
    if name.endswith(".weight"):
        torch.nn.init.normal_(param, mean=0., std=0.01)
    elif name.endswith(".bias"):
        torch.nn.init.constant_(param, val=0)
    else:
        raise RuntimeError(f"{name} not be init")

定义损失函数

loss = torch.nn.CrossEntropyLoss(reduction="mean")

定义优化算法

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

训练模型

num_epochs = 5

def train():
    print("begin train...")

    for epoch in range(1, num_epochs + 1):
        total_loss = 0.
        total = 0
        true_sum = 0

        for x, y in train_iter:

            y_hat = net(x)

            # y_hat.shape: B*num_output, y.shape: B*1
            # ll 是标量, 默认mean
            ll = loss(y_hat, y)

            # 进行优化
            # 清空grad
            optimizer.zero_grad()

            ll.backward()

            optimizer.step()

            # ll是 mean, 所以需要乘以 y.shape[0]
            total += y.shape[0]
            total_loss += ll * y.shape[0]

            # 计算 acc
            true_sum += torch.sum(torch.argmax(y_hat, dim=-1) == y).item()

        epoch_loss = total_loss / total
        acc = true_sum / total
        print(f"epoch: {epoch}, total: {total}, loss: {epoch_loss}, acc: {acc}")

开始训练

train()

相关文章

  • 图像分类

    图像分类入门 -图像分类的概念 背景与意义 所谓图像分类问题,就是已有固定的分类标签集合,然后对于输入的图像,从分...

  • python计算机视觉深度学习3图像分类基础

    什么是图像分类? 图像分类的核心任务是从预定义的一类图像中为图像分配标签。分析输入图像并返回标签对图像进行分类。标...

  • python计算机视觉深度学习工具3图像分类基础

    什么是图像分类? 图像分类的核心任务是从预定义的一类图像中为图像分配标签。分析输入图像并返回标签对图像进行分类。标...

  • 标注组件-react版

    组件支持标注类型:1、图像 — 浏览、标注集合展示2、图像分类 — 支持对图像进行分类标注3、图像检测 — 支持对...

  • 2018-10-17

    数据驱动方法 图像分类 图像分类时,分类系统接受一些输入图像,比如猫咪,并且系统已经清楚了一些确定分类或者标签的集...

  • [CS231n]Lecture 2 Image Classifi

    本节内容:图像分类概述、KNN、线性分类器 图像分类是计算机视觉的核心问题。 问题 :语义分割的鸿沟。图像仅仅是一...

  • 机器学习5(轻量TensorFlow)教程

    2. 图像分类器 底层技术依靠TensorFlow实现,此图像分类器利用了Mobilenet分类模型 2.1. 用...

  • 图像分类

    Above All 机器学习的大作业是写图像分类。这里我整理一些有用的参考资料,以便后来提交报告的时候逻辑比较清晰...

  • 图像分类

    Lecture 2: Image Classification pipeline Image Classifica...

  • 图像分类

    李沐 《动手学深度学习》 第三章 读取数据 定义和初始化模型 初始化参数 也可以通过 name paramter ...

网友评论

      本文标题:图像分类

      本文链接:https://www.haomeiwen.com/subject/rzoyvhtx.html