美文网首页Pytorch
Pytorch实现Lenet5模型(FashionMNIST)

Pytorch实现Lenet5模型(FashionMNIST)

作者: Lornatang | 来源:发表于2018-08-11 12:04 被阅读188次

不说废话,直接上代码。

"""
# author: shiyipaisizuo
# contact: shiyipaisizuo@gmail.com
# file: lenet5.py
# time: 2018/7/31 10:06
# license: MIT
"""

import argparse
import os
import time

import torch
import torchvision
from torchvision import transforms
from torch import nn as nn
from torch.optim import Adam

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, default='../data/fashion',
                    help="""image path. Default='../data/fashion'.""")
parser.add_argument('--epochs', type=int, default=200,
                    help="""num epochs. Default=200""")
parser.add_argument('--num_classes', type=int, default=10,
                    help="""0 ~ 9,. Default=10""")
parser.add_argument('--batch_size', type=int, default=100,
                    help="""batch size. Default=100""")
parser.add_argument('--lr', type=float, default=0.001,
                    help="""learing_rate. Default=0.001""")
parser.add_argument('--model_path', type=str, default='../../model/pytorch/mnist/fashion_mnist',
                    help="""Save model path""")
parser.add_argument('--model_name', type=str, default='lenet5.pth',
                    help="""Model name.""")
parser.add_argument('--display_epoch', type=int, default=2)
args = parser.parse_args()

# Create model
if not os.path.exists(args.model_path):
    os.makedirs(args.model_path)

# Define transforms.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Fashion mnist dataset
train_dataset = torchvision.datasets.FashionMNIST(root=args.path,
                                                  train=True,
                                                  transform=train_transform,
                                                  download=True)

test_dataset = torchvision.datasets.FashionMNIST(root=args.path,
                                                 train=False,
                                                 transform=test_transform,
                                                 download=True)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False)


# Create nerual network
class LeNet(nn.Module):
    def __init__(self, category=args.num_classes):
        super(LeNet, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(1, 6, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5, stride=1, padding=0),
            nn.MaxPool2d(2, 2)
        )

        self.fc = nn.Sequential(
            nn.Linear(400, 120),
            nn.Linear(120, 84),
            nn.Linear(84, category)
        )

    def forward(self, x):
        out = self.layer(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


# Load model
model = LeNet().to(device)
print(LeNet())
# cast
cast = nn.CrossEntropyLoss()
# Optimization
optimizer = Adam(model.parameters(), lr=args.lr)


def main():
    model.train()
    for epoch in range(1, args.epochs + 1):
        start = time.time()
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = cast(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % args.display_epoch == 0 or epoch == 1:
            end = time.time()
            print(f"Epoch [{epoch}/{args.epochs}], "
                  f"Loss: {loss.item():.8f}, "
                  f"Time: {(end-start):.1f}sec!")

    # Test the model
    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Test Accuracy: {(correct / args.batch_size):.2f}%")

    # Save the model checkpoint
    torch.save(model, args.model_path + args.model_name)


if __name__ == '__main__':
    main()
"""
Acc: 0.993
"""

原文地址

相关文章

  • Pytorch实现Lenet5模型(FashionMNIST)

    不说废话,直接上代码。 原文地址

  • PyTorch实现经典网络之LeNet5

    简介 本文是使用PyTorch来实现经典神经网络结构LeNet5,并将其用于处理MNIST数据集。LeNet5出自...

  • 人人都是毕加索

    基于 Pytorch 和 VGG19 模型实现图片风格迁移。 相关 Pytorch 官方教程 相关 Github ...

  • CV-字符识别模型

    Pytorch构建CNN模型 Pytorch中构建CNN模型只需要定义好模型的参数和正向传播就可以,Pytorch...

  • 动手学深度学习(一) 线性回归

    线性回归 主要内容包括: 线性回归的基本要素 线性回归模型从零开始的实现 线性回归模型使用pytorch的简洁实现...

  • 线性回归

    线性回归 主要内容包括: 线性回归的基本要素 线性回归模型从零开始的实现 线性回归模型使用pytorch的简洁实现...

  • 第一天-线性回归,Softmax与分类模型,多层感知机

    线性回归 主要内容包括: 线性回归的基本要素 线性回归模型从零开始的实现 线性回归模型使用pytorch的简洁实现...

  • VAE模型实现(pytorch)

    参考论文:官方模型 1、网络结构 结构如下,中间层合并是再参数化技巧,具体可以推导可以参考:https://zhu...

  • 存下-有空总结看看

    教程 | 从头开始了解PyTorch的简单实现 线性模型已退场,XGBoost时代早已来

  • 打卡2020-02-14

    线性回归 从零开始的实现 使用pytorch的简洁实现 softmax与分类模型 softmax某个类别的soft...

网友评论

    本文标题:Pytorch实现Lenet5模型(FashionMNIST)

    本文链接:https://www.haomeiwen.com/subject/jmfdbftx.html