PyTorch Study Notes 7 - Handwritten Digit Recognition with PyTorch

Author: 小新_XX | Published 2019-04-18 21:53

Complete code for this note: https://github.com/ChenWentai/PyTorch/blob/master/task7.py

This is the last note in the PyTorch series. It uses PyTorch to build a neural network that solves the handwritten digit recognition problem.

1. Importing and visualizing the data

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# define the batch size
batch_size = 64

# download the MNIST dataset
train_dataset = datasets.MNIST(root='./data/',
                               train=True,
                               transform=transforms.ToTensor(), 
                               download=True)
test_dataset = datasets.MNIST(root='./data/',
                               train=False,
                               transform=transforms.ToTensor())

# wrap the MNIST datasets in DataLoaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# visualize one training sample
import matplotlib.pyplot as plt

plt.figure()
plt.imshow(train_dataset.data[0].numpy(), cmap='gray')
plt.show()
[Figure: the first training image, a handwritten digit]
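It is worth sanity-checking what the loaders yield before building a model: each batch is an `(images, labels)` pair with images shaped `[batch, 1, 28, 28]`. A minimal sketch, using random tensors in place of the downloaded MNIST data so it runs standalone:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 256 random 1x28x28 "images" with integer labels 0-9.
fake_images = torch.randn(256, 1, 28, 28)
fake_labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(fake_images, fake_labels),
                    batch_size=64, shuffle=True)

# Each batch is an (images, labels) pair, just as with the MNIST loaders above.
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])
```

The same `next(iter(train_loader))` pattern works on the real MNIST loaders defined above.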

2. Building the network

This post uses the famous LeNet5, proposed by Yann LeCun in 1998, as the model. LeNet5 contains two 5x5 convolutional layers, two 2x2 pooling layers, and three fully connected layers; its structure is shown below:


[Figure: LeNet5 architecture]
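A useful check before writing the class is to trace the spatial size through the layers, which explains where the `16*5*5` input size of the first fully connected layer comes from. A back-of-the-envelope sketch using the standard conv/pool size formula:

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output spatial size of a conv layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                    # MNIST input is 28x28
size = conv_out(size, 5, 2)  # conv1: 5x5 kernel, padding 2 -> stays 28
size //= 2                   # 2x2 max pooling -> 14
size = conv_out(size, 5)     # conv2: 5x5 kernel, no padding -> 10
size //= 2                   # 2x2 max pooling -> 5
print(size)                  # 5, so the flattened size is 16 * 5 * 5 = 400
```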

Code:

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = F.max_pool2d(torch.tanh(self.conv1(x)), (2, 2))
        # x = F.dropout(x, p=0.3, training=self.training)  # optional dropout
        x = F.max_pool2d(torch.tanh(self.conv2(x)), (2, 2))
        # x = F.dropout(x, p=0.3, training=self.training)
        x = x.view(-1, self.num_flat_features(x))
        x = torch.tanh(self.fc1(x))
        # x = F.dropout(x, p=0.3, training=self.training)
        x = torch.tanh(self.fc2(x))
        # x = F.dropout(x, p=0.3, training=self.training)
        x = self.fc3(x)
        return x
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
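Before training, it is worth verifying the forward pass with a dummy batch: the output should be `[batch, 10]` logits. A standalone sketch (the class is repeated in compact form here so the snippet runs on its own):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        x = F.max_pool2d(torch.tanh(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)

# Feed a dummy batch of four 1x28x28 images; expect [4, 10] logits back.
model = LeNet5()
out = model(torch.zeros(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 10])
```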

3. Training the network

from torch.autograd import Variable
from time import time
def train(model, num_epoch):
    Loss = []
    model.train(True)
    for i in range(num_epoch):
        running_loss = 0
        running_corrects = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            running_corrects += pred.eq(target.view_as(pred)).sum().item()
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = running_corrects / len(train_dataset)
        print("Epoch:", i, "running_loss:", running_loss, "Loss:", epoch_loss)
        print("Epoch:", i, "running_corrects:", running_corrects, "acc:", epoch_acc)
        Loss.append(epoch_loss)
    return Loss
model = LeNet5()
# print the network structure
print(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()
a = time()
LOSS = train(model, 30)    
b = time() - a
print("training time:{}s".format(b))
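Once training finishes, the learned weights can be saved and restored via `state_dict` — a standard PyTorch pattern, sketched here with a tiny stand-in model (with the code above you would pass the trained `LeNet5`; the filename is illustrative):

```python
import torch
import torch.nn as nn

# Tiny stand-in model; substitute the trained LeNet5 in the real notebook.
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "lenet5_mnist.pt")  # filename is arbitrary

# Restoring requires an instance with the same architecture.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("lenet5_mnist.pt"))
restored.eval()  # switch to evaluation mode before inference
print(torch.equal(model.weight, restored.weight))  # True
```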

Visualize the training loss:

plt.plot(LOSS)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

[Figure: training loss curve]

4. Testing the network

def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for data, target in test_loader:
            output = model(data)
            # accumulate the loss over all batches
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print("test_loss:", test_loss)
    acc = correct / len(test_dataset)
    print("test accuracy:", acc)
test()

Results:

test_loss: 0.0011666176065802574
test accuracy: 0.9841
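With a trained model in hand, classifying a single image is just an `argmax` over the logits. A minimal sketch with a stand-in classifier and a random image (in the real notebook you would use the trained `LeNet5` and an image from `test_dataset`):

```python
import torch
import torch.nn as nn

# Stand-in classifier over flattened 28x28 images; replace with the
# trained LeNet5 when running against real MNIST data.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

image = torch.randn(1, 1, 28, 28)   # one image, shaped like an MNIST sample
with torch.no_grad():               # inference only, no gradients
    logits = model(image)
pred = logits.argmax(dim=1).item()  # predicted class index, 0-9
print(pred)
```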

Original article: https://www.haomeiwen.com/subject/ltjxgqtx.html