单层 Softmax 分类
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from random import randint
import numpy as np
import matplotlib.pyplot as plt
# Directory where the TensorBoard graph and summaries are written.
logs_patch = 'log_simple_stats_softmax'
batch_size = 100  # mini-batch size
learning_rate = 0.005  # step size for the Adam optimizer defined below
training_epochs = 10  # number of passes over the training set
# Load MNIST (downloaded into "MNIST_data" if needed) with one-hot labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Model: Y = softmax(XX * W + b)
# Input images in 28x28x1 layout (the layout the feeding code reshapes to).
X = tf.placeholder(tf.float32, [None, 28, 28, 1], name="X")
# One-hot ground-truth labels for the 10 digit classes.
Y_ = tf.placeholder(tf.float32, [None, 10], name="Y")
W = tf.Variable(tf.zeros([784, 10]), name="W")
# Flatten each image into a 784-element row vector.
XX = tf.reshape(X, [-1, 784])
b = tf.Variable(tf.zeros([10]), name="b")
# Unnormalized class scores (logits): XX * W + b
evidence = tf.matmul(XX, W) + b
# Softmax turns the logits into class probabilities.
Y = tf.nn.softmax(evidence, name="output")
# Cross-entropy loss computed directly from the logits. The fused op is
# numerically stable, whereas tf.log(softmax) produces -inf/NaN as soon as a
# predicted probability underflows to 0. The per-example mean is multiplied
# by 100 so the reported cost is the total loss over a batch of 100 images.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=Y_, logits=evidence)) * 100.0
# Minimize the loss with the Adam optimizer.
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
# Accuracy: fraction of examples whose most probable class matches the label.
correct_prediction = tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Scalar summaries for TensorBoard.
tf.summary.scalar("cost", cross_entropy)
tf.summary.scalar("accuracy", accuracy)
summary_op = tf.summary.merge_all()
with tf.Session() as sess:
    # Initialize W and b.
    sess.run(tf.global_variables_initializer())
    # Records the graph and the per-step summaries for TensorBoard.
    writer = tf.summary.FileWriter(logs_patch, graph=tf.get_default_graph())
    # Number of mini-batches per epoch; loop-invariant, so computed once.
    batch_count = int(mnist.train.num_examples / batch_size)
    try:
        for epoch in range(training_epochs):
            for i in range(batch_count):
                batch_x, batch_y = mnist.train.next_batch(batch_size)
                # Feed the declared input placeholder X (reshaping the flat
                # batch to its 28x28x1 layout), exactly as evaluation and
                # prediction do, and record the merged summaries.
                _, summary = sess.run(
                    [train_step, summary_op],
                    feed_dict={X: np.reshape(batch_x, [-1, 28, 28, 1]),
                               Y_: batch_y})
                writer.add_summary(summary, epoch * batch_count + i)
            print("Epoch: ", epoch)
    finally:
        # Flush pending events to disk and release the event file.
        writer.close()

    # Evaluate accuracy on the whole test set.
    test_images = np.reshape(mnist.test.images, [-1, 28, 28, 1])
    print("Accuracy: ", accuracy.eval(feed_dict={X: test_images,
                                                 Y_: mnist.test.labels}))
    print("done")

    # Predict one random test image. randint includes both endpoints, so the
    # upper bound must be the last valid index, not the number of images.
    num = randint(0, mnist.test.images.shape[0] - 1)
    test_img = np.reshape(mnist.test.images[num], [28, 28, 1])
    test_label = mnist.test.labels[num]
    # Predicted label = index of the most probable class.
    classification = sess.run(tf.argmax(Y, 1), feed_dict={X: [test_img]})
    plt.imshow(np.reshape(test_img, [28, 28]))
    plt.show()
    print("predict_label: ", classification[0])
    print("true_label: ", np.argmax(test_label))
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Accuracy: 0.9218
done
predict_label: 5
true_label: 5
![output_1_1](output_1_1.png)
网友评论