美文网首页
简单的手写体识别

简单的手写体识别

作者: 枯燥一一cave | 来源:发表于2018-10-14 13:10 被阅读0次

import tensorflowas tf

import matplotlib.pylabas plt

import random

"""

通过input_data.read_data_sets函数生成的类会自动将MNIST数据集划分为train, validation和test三个数据集,

其中train这个集合内含有55000张图片,validation集合内含有5000张图片,这两个集合组成了MNIST本身提供的训练数据集。

test集合内有10000张图片,这些图片都来自与MNIST提供的测试数据集。处理后的每一张图片是一个长度为784的一维数组,

这个数组中的元素对应了图片像素矩阵中的每一个数字(28*28=784)。因为神经网络的输入是一个特征向量,

所以在此把一张二维图像的像素矩阵放到一个一维数组中可以方便tensorflow将图片的像素矩阵提供给神经网络的输入层。

像素矩阵中元素的取值范围为[0, 1],它代表了颜色的深浅。其中0表示白色背景,1表示黑色前景。为了方便使用随机梯度下降,

input_data.read_data_sets函数生成的类还提供了mnist.train.next_batch函数,

它可以从所有的训练数据中读取一小部分作为一个训练batch

"""

tf.set_random_seed(777)

#手写体数据集

from tensorflow.examples.tutorials.mnistimport input_data

mnist=input_data.read_data_sets("MNIST_data",one_hot=True)

nb_class=10

#图片属性定义

X=tf.placeholder(tf.float32,shape=[None,784])

Y=tf.placeholder(tf.float32,shape=[None,nb_class])

W=tf.Variable(tf.random_normal([784,nb_class]))

b=tf.Variable(tf.random_normal([nb_class]))

#函数该概率化

h=tf.nn.softmax(tf.matmul(X,W)+b)

cost=tf.reduce_mean(-tf.reduce_sum(Y*tf.log(h),axis=1))

opmintizer=tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)#梯度下降算法

#测试模型

#Test Model

incoret=tf.equal(tf.arg_max(h,1),tf.arg_max(Y,1))

#计算准确度

Accuracy=tf.reduce_mean(tf.cast(incoret,tf.float32))

traing_cs=15

batch_size=100

with tf.Session()as sess:

sess.run(tf.global_variables_initializer())

#训练模型

    for epochin range(traing_cs):

avg_cost=0

        total_bach=int(mnist.train.num_examples/batch_size)

for iin range(total_bach):

batch_xs,batch_ys=mnist.train.next_batch(batch_size)

c,_=sess.run([cost,opmintizer],feed_dict={X:batch_xs,Y:batch_ys})

avg_cost+=c/total_bach

print('Epoch',"%04d"%(epoch+1),'cost=',"{:.9f}".format(avg_cost))

print("learning finished")

# 测试模型

    print("Accuracy:",Accuracy.eval(session=sess,feed_dict={X:mnist.test.images,Y:mnist.test.labels}))

# 获取预测值

 r=random.randint(0,mnist.test.num_examples-1)

print("Labels:",sess.run(tf.arg_max(mnist.test.labels[r:r+1],1)))

print('Prediction:',sess.run(tf.argmax(h,1),feed_dict={X:mnist.test.images[r:r+1]}))

plt.imshow(mnist.test.images[r:r+1].reshape(28,28),cmap='Greys',interpolation='nearest')

plt.show()

相关文章

网友评论

      本文标题:简单的手写体识别

      本文链接:https://www.haomeiwen.com/subject/mtvzaftx.html