之前一直用tensorflow处理mnist手写识别数据集,尽管天禹强烈安利keras,但是本人完全无动于衷,表示就算饿死也不会用keras。
本文参考了清华大学出版社的《tensorflow+keras深度学习人工智能时间应用》的大部分内容,并纠正了书中的一点点错误。若本文有任何谬误欢迎指正,本人将尽快更改并深表谢意。
以下使用Out[数字]:
的方式表示代码输出结果
Out[数字]:
代码输出
下载mnist数据
导入keras及相关模块
import numpy as np
import pandas as pd
from keras.utils import np_utils
np.random.seed(10)
导入mnist数据集
from keras.datasets import mnist
#下载及读取mnist
(X_train_image,y_train_label), \
(X_test_image,y_test_label)=mnist.load_data()
#查看mnist
print('train data', len(X_train_image))
print('test data',len(X_test_image))
# 查看训练数据
print('X_train_image:',X_train_image.shape)
print('y_train_label:',y_train_label.shape)
定义plot_image函数显示数字图像
为了能够显示images数字图像,我们创建下列plot_images函数
import matplotlib.pyplot as plt #导入matplotlib.pyplot模块
def plot_image(image): #定义plot_image函数,传入image作为参数
fig=plt.gcf() #设置显示图形的大小
fig.set_size_inches(2,2)
plt.imshow(image,cmap='binary') #使用plt.show显示图形,传入参数image是28x28的图形,cmap参数设置为binary,以黑白灰度显示
plt.show() #开始绘图
执行plt_image函数查看第0个数字图像
以下程序调用plot_image函数传入mnist.train.images[0],也就是训练数据集的第0项数据,从显示结果可以看出这是一个数字5的图形
plot_image(X_train_image[0])
查看多项数据images和label
创建plot_images_labels_prediction()函数
创建一个函数用来方便地查看数字图形,实际数字与预测结果。
import matplotlib.pyplot as plt
def plot_images_labels_prediction(image,labels,prediction,idx,num=10): #定义plot_images_labels_prediction
fig=plt.gcf() #设置显示图像的大小
fig.set_size_inches(12,14)
if num>25: num=25 #如果显示项数大于25,就设置为25,以免发生错误
for i in range(0,num): #for循环执行程序块内的程序代码,画出num个数字图形
ax=plt.subplot(5,5,1+i) #建立subgraph子图形为5行5列
ax.imshow(image[idx],cmap='binary') #画出subgraph图形
title= "label=" +str(labels[idx]) #设置子图形title,显示标签字段
if len(prediction)>0: #如果传入了预测结果
title+=",prdict"+str(prediction[idx]) #标题
ax.set_title(title,fontsize=10) #设置子图形的标题
ax.set_xticks([]);ax.set_yticks([]) #设置不显示刻度
idx+=1 #读取下一项
plt.show()
plot_images_labels_prediction函数需要传入下列参数:
数字的图像:image
实际值:labels
预测结果:prediction
开始显示的数据index: idx
要显示的数据项数(默认是10,不超过25):num
查看训练数据前10项数据
执行plot_images_labels_prediction函数显示前10项训练数据。输入X_train_image和y_train_label,不过目前还没有预测结果prediction,所以传入空list[],从第0项数据一直显示到第9项数据
plot_images_labels_prediction(X_train_image,y_train_label,[],0,10)
训练数据前10项
查看test测试数据
查看test测试数据项数,我们可以看到共计10000项数据
print('X_test_image:',X_test_image.shape)
print('y_test_label:',y_test_label.shape)
Out[10]:
X_test_image: (10000, 28, 28)
y_test_label: (10000,)
显示test测试数据
执行plot_images_labels_prediction显示前10项测试数据
plot_images_labels_prediction(X_test_image,y_test_label,[],0,10)
前10项测试数据
featrues数据预处理
features(数字图像的特征值)数据预处理可分为下列两个步骤:
(1)将原本28x28的数字图像以reshape转换为一维的向量,其长度是784,并且转换为float
(2)数字图像image的数字标准化
步骤一 查看image的shape
可以用下列指令查看每个数字图像的shape是28x28
print('X_train_image:',X_train_image.shape)
print('y_train_label:',y_train_label.shape)
Out[12]:
X_train_image: (60000, 28, 28)
y_train_label: (60000,)
步骤二 将image以reshape转换
下面的程序代码将原本28x28的二维数字以reshape转换为一维的向量,再以astype转换为float,共784个浮点数
x_train=X_train_image.reshape(60000,784).astype('float32')
x_test=X_test_image.reshape(10000,784).astype('float32')
可以用下列指令查看每一个数字图像是784个浮点数
print('x_train:',x_train.shape)
print('x_test:',x_test.shape)
步骤三 查看转换为一维向量的shape
print('x_train:',x_train.shape)
print('x_test:',x_test.shape)
Out[14]:
x_train: (60000, 784)
x_test: (10000, 784)
步骤四 查看image图像的内容
查看image第0项的内容
X_train_image[0]
从以上执行结果可知,每一个数字都是从0到255的值,代表图形每一个点灰度的深浅,其中大部分数字都是0。
步骤5 将数字图像images的数字标准化
image的数字标准化可以提高后续训练模型的准确率,因为image的数字是从0到255的值,所以最简单的标准化方式是除以255
x_train_normalize=x_train/255
x_test_normalize=x_test/255
步骤6 查看数字图像images数字标准化后的结果
label数据预处理
label(数字图像真实的值)标签字段原本是0-9的数字,必须以one-hot encoding(一位有效编码)转换为10个0或1的组合,例如数字7经过one-hot encoding转换后是0000000100,正好对应输出层的1个神经元
步骤一 查看原本的label标签字段
以下列指令来查看训练数据label标签字段的前5项训练数据,我们可以看到这是0-9的数字
y_train_label[:5]
Out[18]:
array([5, 0, 4, 1, 9], dtype=uint8)
步骤二 label标签字段进行one-hot encoding转换
下面的程序代码使用np_utils.to_categorical分别传入参数y_train_label(训练数据)和y_test_label(测试数据)的label标签字段,进行one-hot encoding转换
y_train_onehot=np_utils.to_categorical(y_train_label)
y_test_onehot=np_utils.to_categorical(y_test_label)
步骤三 查看进行one-hot encoding转换之后的label标签字段
进行one-hot encoding转换之后,查看训练数据label标签字段的前5项数据,我们可以看到转换后的结果
y_train_onehot[:5]
Out[20]:
array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)
网友评论