import tensorflowas tf
import matplotlib.pylabas plt
import random
"""
通过input_data.read_data_sets函数生成的类会自动将MNIST数据集划分为train, validation和test三个数据集,
其中train这个集合内含有55000张图片,validation集合内含有5000张图片,这两个集合组成了MNIST本身提供的训练数据集。
test集合内有10000张图片,这些图片都来自与MNIST提供的测试数据集。处理后的每一张图片是一个长度为784的一维数组,
这个数组中的元素对应了图片像素矩阵中的每一个数字(28*28=784)。因为神经网络的输入是一个特征向量,
所以在此把一张二维图像的像素矩阵放到一个一维数组中可以方便tensorflow将图片的像素矩阵提供给神经网络的输入层。
像素矩阵中元素的取值范围为[0, 1],它代表了颜色的深浅。其中0表示白色背景,1表示黑色前景。为了方便使用随机梯度下降,
input_data.read_data_sets函数生成的类还提供了mnist.train.next_batch函数,
它可以从所有的训练数据中读取一小部分作为一个训练batch
"""
tf.set_random_seed(777)
#手写体数据集
from tensorflow.examples.tutorials.mnistimport input_data
mnist=input_data.read_data_sets("MNIST_data",one_hot=True)
nb_class=10
#图片属性定义
X=tf.placeholder(tf.float32,shape=[None,784])
Y=tf.placeholder(tf.float32,shape=[None,nb_class])
W=tf.Variable(tf.random_normal([784,nb_class]))
b=tf.Variable(tf.random_normal([nb_class]))
#函数该概率化
h=tf.nn.softmax(tf.matmul(X,W)+b)
cost=tf.reduce_mean(-tf.reduce_sum(Y*tf.log(h),axis=1))
opmintizer=tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)#梯度下降算法
#测试模型
#Test Model
incoret=tf.equal(tf.arg_max(h,1),tf.arg_max(Y,1))
#计算准确度
Accuracy=tf.reduce_mean(tf.cast(incoret,tf.float32))
traing_cs=15
batch_size=100
with tf.Session()as sess:
sess.run(tf.global_variables_initializer())
#训练模型
for epochin range(traing_cs):
avg_cost=0
total_bach=int(mnist.train.num_examples/batch_size)
for iin range(total_bach):
batch_xs,batch_ys=mnist.train.next_batch(batch_size)
c,_=sess.run([cost,opmintizer],feed_dict={X:batch_xs,Y:batch_ys})
avg_cost+=c/total_bach
print('Epoch',"%04d"%(epoch+1),'cost=',"{:.9f}".format(avg_cost))
print("learning finished")
# 测试模型
print("Accuracy:",Accuracy.eval(session=sess,feed_dict={X:mnist.test.images,Y:mnist.test.labels}))
# 获取预测值
r=random.randint(0,mnist.test.num_examples-1)
print("Labels:",sess.run(tf.arg_max(mnist.test.labels[r:r+1],1)))
print('Prediction:',sess.run(tf.argmax(h,1),feed_dict={X:mnist.test.images[r:r+1]}))
plt.imshow(mnist.test.images[r:r+1].reshape(28,28),cmap='Greys',interpolation='nearest')
plt.show()
网友评论