美文网首页
手写数字识别

手写数字识别

作者: 闫_锋 | 来源:发表于2018-05-23 16:36 被阅读72次
  • 图像识别(Image Recognition)是指利用计算机对图像进行处理、分析和理解,以识别各种不同模式的目标和对像的技术。

  • 图像识别的发展经历了三个阶段:文字识别、数字图像处理与识别、物体识别。机器学习领域一般将此类识别问题转化为分类问题。

  • 手写识别是常见的图像识别任务。计算机通过手写体图片来识别出图片中的字,与印刷字体不同的是,不同人的手写体风格迥异,大小不一,造成了计算机对手写识别任务的一些困难。

  • 数字手写体识别由于其有限的类别(0~9共10个数字)成为了相对简单的手写识别任务。DBRHD和MNIST是常用的两个数字手写识别数据集。

已有许多模型在MNIST或DBRHD数据集上进行了实验,有些模型对数据集进行了偏斜矫正,甚至在数据集上进行了人为的扭曲、偏移、缩放及失真等操作以获取更加多样性的
样本,使得模型更具有泛化性。

  • 常用于数字手写体的分类器:
    1. 线性分类器
    2. K最近邻分类器
    3. Boosted Stumps
    4. 非线性分类器
    5. SVM
    6. 多层感知器
    7. 卷积神经网络
  • 后续任务:利用全连接的神经网络实现手写识别的任务

MLP输出:“one-hot vectors”

  • 一个one-hot向量除了某一位的数字是1以外其余各维度数字都是0。
  • 图片标签将表示成一个只有在第n维度(从0开始)数字为1的10维向量。比如,标签0将表示成[1,0,0,0,0,0,0,0,0,0,0]。即, MLP输出层具有10个神经元。
import numpy as np  # 导入numpy工具包
from os import listdir  # 使用listdir模块,用于访问本地文件
from sklearn.neural_network import MLPClassifier


def img2vector(fileName):
    retMat = np.zeros([1024], int)  # 定义返回的矩阵,大小为1*1024
    fr = open(fileName)  # 打开包含32*32大小的数字文件
    lines = fr.readlines()  # 读取文件的所有行
    for i in range(32):  # 遍历文件所有行
        for j in range(32):  # 并将01数字存放在retMat中
            retMat[i * 32 + j] = lines[i][j]
    return retMat


def readDataSet(path):
    fileList = listdir(path)  # 获取文件夹下的所有文件
    numFiles = len(fileList)  # 统计需要读取的文件的数目
    dataSet = np.zeros([numFiles, 1024], int)  # 用于存放所有的数字文件
    hwLabels = np.zeros([numFiles, 10])  # 用于存放对应的one-hot标签
    for i in range(numFiles):  # 遍历所有的文件
        filePath = fileList[i]  # 获取文件名称/路径
        digit = int(filePath.split('_')[0])  # 通过文件名获取标签
        hwLabels[i][digit] = 1.0  # 将对应的one-hot标签置1
        dataSet[i] = img2vector(path + '/' + filePath)  # 读取文件内容
    return dataSet, hwLabels


# read dataSet
train_dataSet, train_hwLabels = readDataSet('trainingDigits')

clf = MLPClassifier(hidden_layer_sizes=(100,),
                    activation='logistic', solver='adam',
                    learning_rate_init=0.0001, max_iter=2000)
print(clf)
clf.fit(train_dataSet, train_hwLabels)

# read  testing dataSet
dataSet, hwLabels = readDataSet('testDigits')
res = clf.predict(dataSet)  # 对测试集进行预测
error_num = 0  # 统计预测错误的数目
num = len(dataSet)  # 测试集的数目
for i in range(num):  # 遍历预测结果
    # 比较长度为10的数组,返回包含01的数组,0为不同,1为相同
    # 若预测结果与真实结果相同,则10个数字全为1,否则不全为1
    if np.sum(res[i] == hwLabels[i]) < 10:
        error_num += 1
print("Total num:", num, " Wrong num:", \
      error_num, "  WrongRate:", error_num / float(num))

相关文章

网友评论

      本文标题:手写数字识别

      本文链接:https://www.haomeiwen.com/subject/ajvljftx.html