使用 TensorFlow 提供的 MNIST 数据集训练 softmax 回归分类器
```
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# NOTE(review): presumably TensorFlow's MNIST tutorial helper
# (tensorflow/examples/tutorials/mnist/input_data.py) copied next to this
# script -- confirm.
import input_data

# Load MNIST from ./MNIST_data with labels one-hot encoded.
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)
trainning = mnist.train.images  # (sic) name kept so existing references still work
train_labels = mnist.train.labels
testing = mnist.test.images
test_labels = mnist.test.labels

# Output recorded when this section was run (archive extraction messages,
# then the shapes of the four arrays above):
#   Extracting ./MNIST_data\train-images-idx3-ubyte.gz
#   Extracting ./MNIST_data\train-labels-idx1-ubyte.gz
#   Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz
#   Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz
#   trainning.shape    -> (55000, 784)
#   train_labels.shape -> (55000, 10)
#   testing.shape      -> (10000, 784)
#   test_labels.shape  -> (10000, 10)
# Softmax regression graph.
# Inputs: x holds flattened images (784 features per row), y holds one-hot
# labels over 10 classes; the leading None lets any batch size be fed.
x = tf.placeholder("float32", [None, 784])
y = tf.placeholder("float32", [None, 10])
# Weights and bias both start as all-zero tensors.
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# Unnormalised class scores.
logits = tf.matmul(x, w) + b
# Class probabilities; used below for prediction/accuracy.
actv = tf.nn.softmax(logits)
# Mean cross-entropy over the batch. log_softmax is computed directly from the
# logits instead of tf.log(softmax(...)): when a softmax output underflows to
# 0, tf.log returns -inf and the cost and gradients turn into NaN.
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.nn.log_softmax(logits), axis=1))
learning_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# A prediction counts as correct when the index of the largest predicted
# probability equals the index of the 1 in the one-hot label.
pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
# Accuracy: fraction of correct predictions in whatever batch is fed.
accr = tf.reduce_mean(tf.cast(pred, "float32"))
init = tf.global_variables_initializer()
# Number of passes over the training set.
train_epochs = 50
# Number of samples per mini-batch.
batch_size = 100
# Report progress every `display_step` epochs.
display_step = 5
# Mini-batches per epoch; it does not change between epochs, so compute once.
num_batch = int(mnist.train.num_examples / batch_size)

# The `with` block closes the session when training finishes.
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(train_epochs):
        avg_cost = 0
        for _ in range(num_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            feeds = {x: batch_x, y: batch_y}
            sess.run(optm, feed_dict=feeds)
            # Cost on the same batch after this update, averaged over the epoch.
            avg_cost += sess.run(cost, feed_dict=feeds) / num_batch
        if (epoch + 1) % display_step == 0:
            # train_acc is measured on the last mini-batch of the epoch only.
            feed_train = {x: batch_x, y: batch_y}
            # test_acc must be measured on the held-out test set, not on the
            # training images.
            feed_test = {x: mnist.test.images, y: mnist.test.labels}
            train_acc = sess.run(accr, feed_dict=feed_train)
            test_acc = sess.run(accr, feed_dict=feed_test)
            # Report the 1-based epoch number so it agrees with the
            # (epoch + 1) check above and the "/train_epochs" total.
            print("Epoch:%03d/%03d Cost:%.9f train_acc:%.3f test_acc :%.3f"
                  % (epoch + 1, train_epochs, avg_cost, train_acc, test_acc))
```
网友评论