美文网首页
CV-字符识别模型

CV-字符识别模型

作者: 一技破万法 | 来源:发表于2020-05-19 11:32 被阅读0次

Pytorch构建CNN模型

Pytorch中构建CNN模型只需要定义好模型的参数和正向传播就可以,Pytorch会根据正向传播自动计算反向传播。这里构建一个简单的CNN,包含两个卷积层,六个全连接层。

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
#定义模型
class SVHN_Model1(nn.Module):
    def __init__(self):
        super(SVHN_Model1, self).__init__()
        # CNN提取特征模块
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2)),
            nn.ReLU(),  
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2)),
            nn.ReLU(), 
            nn.MaxPool2d(2),
        )
        # 
        self.fc1 = nn.Linear(32*3*7, 11)
        self.fc2 = nn.Linear(32*3*7, 11)
        self.fc3 = nn.Linear(32*3*7, 11)
        self.fc4 = nn.Linear(32*3*7, 11)
        self.fc5 = nn.Linear(32*3*7, 11)
        self.fc6 = nn.Linear(32*3*7, 11)
    
    def forward(self, img):        
        feat = self.cnn(img)
        feat = feat.view(feat.shape[0], -1)
        c1 = self.fc1(feat)
        c2 = self.fc2(feat)
        c3 = self.fc3(feat)
        c4 = self.fc4(feat)
        c5 = self.fc5(feat)
        c6 = self.fc6(feat)
        return c1, c2, c3, c4, c5, c6
    
model = SVHN_Model1()
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.Adam(model.parameters(), 0.005)

loss_plot, c0_plot = [], []
# 迭代10个Epoch
for epoch in range(10):
    for data in train_loader:
        c0, c1, c2, c3, c4, c5 = model(data[0])
        loss = criterion(c0, data[1][:, 0]) + \
                criterion(c1, data[1][:, 1]) + \
                criterion(c2, data[1][:, 2]) + \
                criterion(c3, data[1][:, 3]) + \
                criterion(c4, data[1][:, 4]) + \
                criterion(c5, data[1][:, 5])
        loss /= 6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loss_plot.append(loss.item())
        c0_plot.append((c0.argmax(1) == data[1][:, 0]).sum().item()*1.0 / c0.shape[0])
        
    print(epoch)
个别代码注释分析

torch.utils.data.DataLoader() #pytorch的一种数据类型
属性:
dataset:输入数据类型
batch_size(int):每次输入数据的行数
shuffle:数据类型bool,是否洗牌
collate_fn:将一小段数据合并成数据列表,默认Forse。True的话就是在系统返回前将张量数据复制到CUDA中
num_workers(int):进程数量
drop_last:数据类型bool,默认False,设置batch_size后最后一批数据可能数量不够,询问是否舍去。

nn.model #nn中的一个类,包含网络各层的定义及forward方法
super(d, self).init() #初始化继承父类的构造函数
nn.Sequential() #一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
nn.CrossEntropyLoss() #多用于多分类问题的交叉熵。
loss(x,class) = -log\dfrac{e^{(x[class])}}{\sum_{j}e^{(x[j])}} = -x[class] + log(\sum_j e^{(x[j])})
若含有权重函数,则公式为:
loss(x,class) = weight[class](-x[class] + log(\sum_j e^{(x[j])}))
torch.optim.Adam(model.parameters(), 0.005) #第一个参数是params - 待优化参数的iterable或者是定义了参数组的dict;第二个参数是lr-默认为1e-3
optimizer.zero_grad() #将梯度初始化为0,因为一个batch的loss关于weight的导数是将所有的sample的loss关于weight的导数的累加和
loss.backward() #反向传播求梯度
optimizer.step() #更新参数

相关文章

  • CV-字符识别模型

    Pytorch构建CNN模型 Pytorch中构建CNN模型只需要定义好模型的参数和正向传播就可以,Pytorch...

  • CV-模型集成

    集成学习方法 集成学习能够提高预测精度,常见的集成学习方法有stacking、bagging和boosting,同...

  • CV-模型训练与验证

    在模型的训练过程中,模型训练误差逐渐降低,但是模型泛化效果较差,这种现象称为过拟合。所以为了提高模型的泛化能力,我...

  • Datawhale 零基础入门CV赛事-Task3 字符识别模型

    3 字符识别模型 本章将会讲解卷积神经网络(Convolutional Neural Network, CNN)的...

  • CV-数据读取与数据扩增

    首先考虑使用[定长字符识别]思路来构建模型。 1. 图像读取 Pillow OpenCV 1.1 Pillow ...

  • Task 3 字符识别模型

    这三天了解了一下CNN(这个打卡写的我想吐 完全是划水 从task2 到task3 跳度有点大)对于这个task ...

  • Task3 字符识别模型

    本次学习目标: 学习CNN基础和原理 使用pytorch框架构建CNN模型,并完成训练 1、卷积神经网络(Conv...

  • Task03 字符识别模型

    一、CNN模型 CNN,又称卷积神经网络,它是一种前馈的神经网络,在图像识别领域有着巨大的应用。 二、如何理解卷积...

  • Task03:字符识别模型

    0. CNN原理和发展 CNN由卷积(convolution)、池化(pooling)、非线性激活函数(non-l...

  • Task3 字符识别模型

    前面学了背景及数据读取,今天开始模型的部分了。使用常用的卷积神经网络CNN,搭建个分类模型。 CNN介绍 CNN是...

网友评论

      本文标题:CV-字符识别模型

      本文链接:https://www.haomeiwen.com/subject/ltqmohtx.html