美文网首页图像分类
pytorch图像分类问题完整流程

pytorch图像分类问题完整流程

作者: 1037号森林里一段干木头 | 来源:发表于2021-05-24 09:31 被阅读0次

简介:本篇文章展示用 pytorch 做图像分类的完整过程。因为在我的应用场景下图片特征简单、对计算速度有要求,所以把网络模型写得很小(当然最终的模型要保密啦),并加入了 SPPnet(空间金字塔池化),使网络对输入图片的尺寸没有要求。

我的训练数据集结构如下:

(原文此处为数据集目录结构截图:train、val 目录下各有 neg、pos 两个类别子文件夹,即 ImageFolder 要求的格式)
数据集划分参考

pytorch图像分类完整流程如下

  • 导入依赖库
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import torchvision
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
  • 模型:带 SPPnet 层,因此对输入图像尺寸没有要求
class net(nn.Module):
    """Small two-conv CNN classifier with a spatial pyramid pooling (SPP) head.

    The SPP layer max-pools the last feature map into 1x1, 2x2, ..., LxL grids
    (L = ``numLevels``) and concatenates them, so the flattened feature length
    depends only on the channel count and ``numLevels``, not on the image size.

    Args:
        channels: number of input image channels (3 for RGB).
        height, width: kept for backward compatibility only; they are not used,
            because the SPP head makes the network independent of input size.
        numLevels: number of pyramid levels in the SPP layer.
    """

    def __init__(self, channels=3, height=128, width=128, numLevels=3):
        super().__init__()
        self.numLevels = numLevels
        spp_channels = 64  # output channels of conv2 == channels entering the SPP layer
        self.conv1 = nn.Conv2d(channels, 16, 3)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, spp_channels, 3)
        # Level l contributes l*l cells per channel; for 3 levels: 64 * (1 + 4 + 9) = 896.
        spp_features = spp_channels * sum(level * level for level in range(1, numLevels + 1))
        self.fc1 = nn.Linear(spp_features, 64)
        self.fc2 = nn.Linear(64, 2)

    def SPPLayer(self, x):
        """Pool ``x`` into ``numLevels`` pyramid grids and concatenate them.

        Args:
            x: 4-D tensor of shape (num, c, h, w).

        Returns:
            Tensor of shape (num, c * sum(l * l for l in 1..numLevels)).
        """
        num, c, h, w = x.size()  # num: batch size, c: channels, h: height, w: width
        pooled = []
        for level in range(1, self.numLevels + 1):
            kernel_size = (math.ceil(h / level), math.ceil(w / level))
            stride = kernel_size
            # Symmetric padding so `level` non-overlapping windows cover the whole map.
            padding = (
                math.floor((kernel_size[0] * level - h + 1) / 2),
                math.floor((kernel_size[1] * level - w + 1) / 2),
            )
            if padding[0] > kernel_size[0] // 2 or padding[1] > kernel_size[1] // 2:
                # max_pool2d rejects padding larger than half the kernel, which happens
                # for tiny feature maps (e.g. h=2 at level 3); adaptive pooling still
                # yields the required level x level grid there.
                out = F.adaptive_max_pool2d(x, (level, level))
            else:
                out = F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)
            pooled.append(out.reshape(num, -1))  # flatten each level, then concatenate
        return torch.cat(pooled, dim=1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.SPPLayer(x)
        # NOTE(review): there is no non-linearity between fc1 and fc2, so together they
        # are a single affine map; kept unchanged so existing trained weights behave the same.
        x = self.fc1(x)
        return self.fc2(x)
        
  • 测试模型结构是否正确:如果能正常输出(形状为 [1, 2] 的张量),说明模型结构正确
# Sanity-check the network on a dummy batch. Tensors are laid out (N, C, H, W) and
# net's constructor takes (channels, height, width), so name the sizes accordingly.
image_h, image_w = 125, 127
model = net(3, image_h, image_w)
x = torch.ones(1, 3, image_h, image_w)
y = model(x)
print(y.size())  # expected: torch.Size([1, 2])
torch.Size([1, 2])
  • 定义验证图片是否正常的函数
def check_image(path):
    """Return True if PIL can open ``path`` as an image, False otherwise.

    Used as ImageFolder's ``is_valid_file`` hook to skip files that PIL cannot
    identify. ``Image.open`` only parses the header, so truncated pixel data is
    not detected here.
    """
    try:
        # Close the file immediately: Image.open keeps the handle open until the
        # image is loaded or garbage-collected, and this runs for every dataset file.
        with Image.open(path):
            return True
    except Exception:
        # Best-effort filter: any failure to open/identify means "not a valid image".
        # Unlike a bare `except:`, KeyboardInterrupt / SystemExit still propagate.
        return False
  • 定义图像预处理操作集
# Preprocessing for every image: convert the PIL image to a float tensor, then
# normalise each of the three channels as (value - 0.5) / 0.3.
img_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5] * 3, std=[0.3] * 3),
    ]
)
  • 定义训练、测试、验证集
# Dataset roots: each must contain one sub-folder per class (ImageFolder layout).
train_data_path = r"K:\imageData\polarity\data3\train"
val_data_path = r"K:\imageData\polarity\data3\val"
# Test split is currently disabled (note: it points at data2, not data3).
# test_data_path = r"K:\imageData\polarity\data2\test"

# Files rejected by check_image are skipped; img_transforms is applied on access.
train_data = torchvision.datasets.ImageFolder(
    root=train_data_path,
    transform=img_transforms,
    is_valid_file=check_image,
)
val_data = torchvision.datasets.ImageFolder(
    root=val_data_path,
    transform=img_transforms,
    is_valid_file=check_image,
)
# test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=img_transforms, is_valid_file=check_image)
  • 定义数据加载器
batch_size = 32
# Shuffle the training set every epoch: ImageFolder lists samples grouped by class
# folder, so an unshuffled loader would feed long single-class runs of batches.
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
# Validation order does not affect loss/accuracy, so it stays unshuffled.
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
  • 定义训练过程
def _safe_mean(total, count):
    """Return total / count, or NaN when count is 0 (e.g. an empty loader)."""
    return total / count if count else math.nan


def _train_one_epoch(model, optimizer, loss_fn, loader, device):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Assumes loss_fn averages over the batch (CrossEntropyLoss default);
        # weighting by batch size turns the epoch total into a per-sample mean.
        total_loss += loss.item() * inputs.size(0)
    return _safe_mean(total_loss, len(loader.dataset))


def _evaluate(model, loss_fn, loader, device):
    """Evaluate ``model`` on ``loader``; return (mean per-sample loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_examples = 0
    with torch.no_grad():  # no autograd graph is needed for validation
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            total_loss += loss_fn(outputs, targets).item() * inputs.size(0)
            # argmax of the logits equals argmax of their softmax.
            num_correct += (outputs.argmax(dim=1) == targets).sum().item()
            num_examples += targets.size(0)
    return _safe_mean(total_loss, len(loader.dataset)), _safe_mean(num_correct, num_examples)


def train(model, optimizer, loss_fn, train_loader, val_loader, epoches=30, device=torch.device("cpu")):
    """Train ``model`` for ``epoches`` epochs, validating after each epoch.

    Args:
        model: the network to train (left in eval mode on return).
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss taking (outputs, targets).
        train_loader: DataLoader yielding (inputs, targets) training batches.
        val_loader: DataLoader yielding (inputs, targets) validation batches.
        epoches: number of epochs to run.
        device: device that each batch is moved to.

    Returns:
        (train_loss_list, valid_loss_list, valid_accuracy_list, epoch_list), one
        entry per epoch; losses/accuracy are NaN for an empty loader.
    """
    train_loss_list = []
    valid_loss_list = []
    valid_accuracy_list = []
    epoch_list = []
    for epoch in range(1, epoches + 1):
        training_loss = _train_one_epoch(model, optimizer, loss_fn, train_loader, device)
        valid_loss, valid_accuracy = _evaluate(model, loss_fn, val_loader, device)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss, valid_loss, valid_accuracy))

        train_loss_list.append(training_loss)
        valid_loss_list.append(valid_loss)
        valid_accuracy_list.append(valid_accuracy)
        epoch_list.append(epoch)

    return train_loss_list, valid_loss_list, valid_accuracy_list, epoch_list
  • 定义损失函数、优化器、运行平台
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Move the model first, then build the optimizer over its (now on-device)
# parameters, as the torch.optim documentation recommends.
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
loss_fn = torch.nn.CrossEntropyLoss()
net(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=896, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=2, bias=True)
)
  • 查看模型参数数量
def get_parameter_number(net):
    """Count the parameters of ``net``.

    Returns:
        dict with 'Total' (all parameter elements) and 'Trainable'
        (elements of parameters with requires_grad=True).
    """
    total = 0
    trainable = 0
    for param in net.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}
get_parameter_number(net=model)  # display total / trainable parameter counts
{'Total': 67266, 'Trainable': 67266}
  • 训练
# Train for 100 epochs and keep the per-epoch curves for plotting.
train_loss_list, valid_loss_list, valid_accuracy_list, epoch_list = train(
    model,
    optimizer,
    loss_fn,
    train_data_loader,
    val_data_loader,
    epoches=100,
    device=device,
)
Epoch: 1, Training Loss: 0.76, Validation Loss: 0.71, accuracy = 0.44
Epoch: 2, Training Loss: 0.70, Validation Loss: 0.69, accuracy = 0.44
Epoch: 3, Training Loss: 0.68, Validation Loss: 0.68, accuracy = 0.52
Epoch: 4, Training Loss: 0.68, Validation Loss: 0.67, accuracy = 0.72
Epoch: 5, Training Loss: 0.67, Validation Loss: 0.65, accuracy = 0.70
Epoch: 6, Training Loss: 0.66, Validation Loss: 0.64, accuracy = 0.70
Epoch: 7, Training Loss: 0.66, Validation Loss: 0.62, accuracy = 0.72
Epoch: 8, Training Loss: 0.65, Validation Loss: 0.61, accuracy = 0.72
Epoch: 9, Training Loss: 0.64, Validation Loss: 0.59, accuracy = 0.68
Epoch: 10, Training Loss: 0.63, Validation Loss: 0.58, accuracy = 0.70
Epoch: 11, Training Loss: 0.61, Validation Loss: 0.56, accuracy = 0.68
Epoch: 12, Training Loss: 0.60, Validation Loss: 0.55, accuracy = 0.68
Epoch: 13, Training Loss: 0.59, Validation Loss: 0.54, accuracy = 0.68
Epoch: 14, Training Loss: 0.58, Validation Loss: 0.53, accuracy = 0.66
Epoch: 15, Training Loss: 0.58, Validation Loss: 0.52, accuracy = 0.66
Epoch: 16, Training Loss: 0.57, Validation Loss: 0.51, accuracy = 0.68
Epoch: 17, Training Loss: 0.56, Validation Loss: 0.50, accuracy = 0.68
Epoch: 18, Training Loss: 0.56, Validation Loss: 0.50, accuracy = 0.68
Epoch: 19, Training Loss: 0.55, Validation Loss: 0.49, accuracy = 0.72
Epoch: 20, Training Loss: 0.55, Validation Loss: 0.49, accuracy = 0.76
Epoch: 21, Training Loss: 0.55, Validation Loss: 0.49, accuracy = 0.76
Epoch: 22, Training Loss: 0.54, Validation Loss: 0.48, accuracy = 0.78
Epoch: 23, Training Loss: 0.54, Validation Loss: 0.48, accuracy = 0.78
Epoch: 24, Training Loss: 0.53, Validation Loss: 0.47, accuracy = 0.78
Epoch: 25, Training Loss: 0.53, Validation Loss: 0.47, accuracy = 0.78
Epoch: 26, Training Loss: 0.53, Validation Loss: 0.47, accuracy = 0.78
Epoch: 27, Training Loss: 0.53, Validation Loss: 0.47, accuracy = 0.78
Epoch: 28, Training Loss: 0.52, Validation Loss: 0.46, accuracy = 0.78
Epoch: 29, Training Loss: 0.52, Validation Loss: 0.46, accuracy = 0.78
Epoch: 30, Training Loss: 0.52, Validation Loss: 0.46, accuracy = 0.78
Epoch: 31, Training Loss: 0.51, Validation Loss: 0.45, accuracy = 0.78
Epoch: 32, Training Loss: 0.51, Validation Loss: 0.45, accuracy = 0.78
Epoch: 33, Training Loss: 0.51, Validation Loss: 0.45, accuracy = 0.78
Epoch: 34, Training Loss: 0.50, Validation Loss: 0.45, accuracy = 0.78
Epoch: 35, Training Loss: 0.50, Validation Loss: 0.44, accuracy = 0.78
Epoch: 36, Training Loss: 0.50, Validation Loss: 0.44, accuracy = 0.80
Epoch: 37, Training Loss: 0.49, Validation Loss: 0.44, accuracy = 0.80
Epoch: 38, Training Loss: 0.49, Validation Loss: 0.44, accuracy = 0.80
Epoch: 39, Training Loss: 0.49, Validation Loss: 0.43, accuracy = 0.82
Epoch: 40, Training Loss: 0.48, Validation Loss: 0.43, accuracy = 0.82
Epoch: 41, Training Loss: 0.48, Validation Loss: 0.43, accuracy = 0.82
Epoch: 42, Training Loss: 0.48, Validation Loss: 0.42, accuracy = 0.82
Epoch: 43, Training Loss: 0.47, Validation Loss: 0.42, accuracy = 0.82
Epoch: 44, Training Loss: 0.47, Validation Loss: 0.42, accuracy = 0.82
Epoch: 45, Training Loss: 0.46, Validation Loss: 0.41, accuracy = 0.82
Epoch: 46, Training Loss: 0.46, Validation Loss: 0.41, accuracy = 0.82
Epoch: 47, Training Loss: 0.46, Validation Loss: 0.41, accuracy = 0.82
Epoch: 48, Training Loss: 0.45, Validation Loss: 0.40, accuracy = 0.82
Epoch: 49, Training Loss: 0.45, Validation Loss: 0.40, accuracy = 0.86
Epoch: 50, Training Loss: 0.44, Validation Loss: 0.40, accuracy = 0.88
Epoch: 51, Training Loss: 0.44, Validation Loss: 0.39, accuracy = 0.88
Epoch: 52, Training Loss: 0.43, Validation Loss: 0.39, accuracy = 0.88
Epoch: 53, Training Loss: 0.43, Validation Loss: 0.38, accuracy = 0.88
Epoch: 54, Training Loss: 0.42, Validation Loss: 0.38, accuracy = 0.88
Epoch: 55, Training Loss: 0.42, Validation Loss: 0.38, accuracy = 0.88
Epoch: 56, Training Loss: 0.41, Validation Loss: 0.37, accuracy = 0.88
Epoch: 57, Training Loss: 0.41, Validation Loss: 0.37, accuracy = 0.88
Epoch: 58, Training Loss: 0.40, Validation Loss: 0.36, accuracy = 0.88
Epoch: 59, Training Loss: 0.40, Validation Loss: 0.36, accuracy = 0.88
Epoch: 60, Training Loss: 0.39, Validation Loss: 0.36, accuracy = 0.88
Epoch: 61, Training Loss: 0.39, Validation Loss: 0.35, accuracy = 0.90
Epoch: 62, Training Loss: 0.38, Validation Loss: 0.35, accuracy = 0.90
Epoch: 63, Training Loss: 0.38, Validation Loss: 0.34, accuracy = 0.90
Epoch: 64, Training Loss: 0.37, Validation Loss: 0.34, accuracy = 0.90
Epoch: 65, Training Loss: 0.37, Validation Loss: 0.33, accuracy = 0.90
Epoch: 66, Training Loss: 0.36, Validation Loss: 0.33, accuracy = 0.92
Epoch: 67, Training Loss: 0.35, Validation Loss: 0.32, accuracy = 0.92
Epoch: 68, Training Loss: 0.35, Validation Loss: 0.32, accuracy = 0.92
Epoch: 69, Training Loss: 0.34, Validation Loss: 0.31, accuracy = 0.92
Epoch: 70, Training Loss: 0.33, Validation Loss: 0.31, accuracy = 0.94
Epoch: 71, Training Loss: 0.33, Validation Loss: 0.30, accuracy = 0.94
Epoch: 72, Training Loss: 0.32, Validation Loss: 0.30, accuracy = 0.94
Epoch: 73, Training Loss: 0.32, Validation Loss: 0.29, accuracy = 0.94
Epoch: 74, Training Loss: 0.31, Validation Loss: 0.29, accuracy = 0.94
Epoch: 75, Training Loss: 0.30, Validation Loss: 0.28, accuracy = 0.94
Epoch: 76, Training Loss: 0.30, Validation Loss: 0.28, accuracy = 0.94
Epoch: 77, Training Loss: 0.29, Validation Loss: 0.27, accuracy = 0.94
Epoch: 78, Training Loss: 0.28, Validation Loss: 0.27, accuracy = 0.94
Epoch: 79, Training Loss: 0.28, Validation Loss: 0.26, accuracy = 0.94
Epoch: 80, Training Loss: 0.27, Validation Loss: 0.26, accuracy = 0.94
Epoch: 81, Training Loss: 0.26, Validation Loss: 0.25, accuracy = 0.94
Epoch: 82, Training Loss: 0.26, Validation Loss: 0.25, accuracy = 0.94
Epoch: 83, Training Loss: 0.25, Validation Loss: 0.24, accuracy = 0.94
Epoch: 84, Training Loss: 0.25, Validation Loss: 0.24, accuracy = 0.94
Epoch: 85, Training Loss: 0.24, Validation Loss: 0.23, accuracy = 0.94
Epoch: 86, Training Loss: 0.23, Validation Loss: 0.23, accuracy = 0.94
Epoch: 87, Training Loss: 0.23, Validation Loss: 0.22, accuracy = 0.94
Epoch: 88, Training Loss: 0.22, Validation Loss: 0.22, accuracy = 0.94
Epoch: 89, Training Loss: 0.22, Validation Loss: 0.21, accuracy = 0.94
Epoch: 90, Training Loss: 0.21, Validation Loss: 0.21, accuracy = 0.94
Epoch: 91, Training Loss: 0.20, Validation Loss: 0.20, accuracy = 0.94
Epoch: 92, Training Loss: 0.20, Validation Loss: 0.20, accuracy = 0.96
Epoch: 93, Training Loss: 0.19, Validation Loss: 0.19, accuracy = 0.96
Epoch: 94, Training Loss: 0.19, Validation Loss: 0.19, accuracy = 0.96
Epoch: 95, Training Loss: 0.18, Validation Loss: 0.18, accuracy = 0.96
Epoch: 96, Training Loss: 0.18, Validation Loss: 0.18, accuracy = 0.96
Epoch: 97, Training Loss: 0.17, Validation Loss: 0.17, accuracy = 0.96
Epoch: 98, Training Loss: 0.17, Validation Loss: 0.17, accuracy = 0.96
Epoch: 99, Training Loss: 0.16, Validation Loss: 0.16, accuracy = 0.96
Epoch: 100, Training Loss: 0.16, Validation Loss: 0.16, accuracy = 0.96
  • 模型保存
torch.save(obj=model, f="K:\\classifier3.pt")  # pickle the whole model object (class + weights)
  • 模型加载
# The file holds a whole pickled nn.Module (see torch.save above). Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses to unpickle such objects,
# so opt out explicitly. SECURITY: weights_only=False executes arbitrary pickle code;
# only load files you trust. map_location lets a GPU-saved model load on a CPU-only host.
load_model = torch.load("K:\\classifier3.pt", map_location=device, weights_only=False)
  • 预测
img_path = r"K:\imageData\polarity\data3\val\pos\00002.bmp"
# img_path = r"K:\imageData\polarity\data3\val\neg\00001.bmp"
# Must match ImageFolder's class indices (class folders sorted by name).
labels = ["neg", "pos"]

# ImageFolder's default loader converts every image to RGB; do the same here so a
# grayscale/palette BMP gives the 3 channels that Normalize and conv1 expect.
# The `with` block closes the underlying file once the tensor has been built.
with Image.open(img_path) as pil_img:
    img = img_transforms(pil_img.convert("RGB")).to(device)
img = torch.unsqueeze(img, 0)  # add the batch dimension: (C, H, W) -> (1, C, H, W)

model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    prediction = F.softmax(model(img), dim=1)
prediction = prediction.argmax()
print(labels[prediction.item()])
pos
  • 网络模型可视化
import netron

# Serve the saved model file in Netron's local web viewer.
saved_model_file = "K:\\classifier3.pt"
netron.start(saved_model_file)
Serving 'K:\classifier3.pt' at http://localhost:8080





('localhost', 8080)
  • 训练过程可视化
def visualize(train_loss, val_loss, val_acc, save_path="K:\\a.png"):
    """Plot per-epoch training loss, validation loss and validation accuracy.

    Args:
        train_loss: per-epoch training losses.
        val_loss: per-epoch validation losses.
        val_acc: per-epoch validation accuracies.
        save_path: file the figure is saved to (defaults to the original path).
    """
    train_loss = np.asarray(train_loss)
    val_loss = np.asarray(val_loss)
    val_acc = np.asarray(val_acc)
    plt.figure()  # fresh figure so repeated calls don't draw over earlier curves
    plt.grid(True)
    plt.xlabel("epoch")
    plt.ylabel("value")
    plt.title("train_loss, valid_loss and valid_acc")
    # Epochs are numbered from 1 by `train`, so the x axis starts at 1 as well.
    plt.plot(np.arange(1, len(val_acc) + 1), val_acc, label=r"valid_acc", c="g")
    plt.plot(np.arange(1, len(train_loss) + 1), train_loss, label=r"train_loss", c="r")
    plt.plot(np.arange(1, len(val_loss) + 1), val_loss, label=r"valid_loss", c="b")
    plt.legend()
    plt.savefig(save_path)
    
visualize(train_loss=train_loss_list, val_loss=valid_loss_list, val_acc=valid_accuracy_list)
(原文此处为训练过程曲线图)

注:从图中的训练损失和验证准确率曲线来看,训练轮次还应继续增加:训练损失仍在下降,验证准确率仍在上升,尚未达到饱和。

相关文章

  • pytorch图像分类问题完整流程

    简介:本篇文章展示pytorch做图像分类的完整过程。因为在我的应用场景下图片特征简单,对计算速度有要求,所以把网...

  • pytorch图像分类完整过程

    简介:本篇文章展示pytorch做图像分类的完整过程。因为在我的应用场景下图片特征简单,对计算速度有要求,所以把网...

  • Pytorch图像分类

    1、Datasets 这段代码可以实现从图片读入数据,文件夹名为label。 2、Pytorch训练 3、将自己的...

  • 使用pytorch深度学习框架实现mnist数据集的图像分类

    此文章是使用pytorch实现mnist手写字体的图像分类。利用pytorch内置函数mnist下载数据,同时利用...

  • pytorch之图像分类

    满心欢喜的来跑这个图像分类,上来就报了个错。安装torchvision 疯狂报这个错:raise NotSuppo...

  • Pytorch实战-图像分类

    用图像实现Pytorch图像分类(一) 总结:使用预训练网络有什么意义当我们人类看到图像时,可以识别线条和形状。正...

  • 使用PyTorch建立图像分类模型

    概述 在PyTorch中构建自己的卷积神经网络(CNN)的实践教程 我们将研究一个图像分类问题——CNN的一个经典...

  • Pytorch 分类问题

    train loss 不断下降,test loss不断下降,说明网络仍在学习; train loss 不断下降,t...

  • 图像分类

    图像分类入门 -图像分类的概念 背景与意义 所谓图像分类问题,就是已有固定的分类标签集合,然后对于输入的图像,从分...

  • pytorch学习(五)—图像的加载/读取方式

    图像加载问题 使用pytorch制作图像数据集时,需要将存储在磁盘、硬盘的图像读取到内存中,涉及到图像I/O问题。...

网友评论

    本文标题:pytorch图像分类问题完整流程

    本文链接:https://www.haomeiwen.com/subject/esatsltx.html