Deep Learning with TensorFlow:cha

Author: csuhan | Published 2019-02-20 17:24

    Softmax single-layer classification
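    For reference, the model built below is the standard single-layer softmax classifier. Writing Y' for the one-hot label placeholder Y_, and up to the batch averaging and the 1000.0 scaling factor applied in the code, it is

    Y = \operatorname{softmax}(XX \cdot W + b), \qquad
    \operatorname{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=0}^{9} e^{z_j}}, \qquad
    \text{cross-entropy} = -\sum_{i} Y'_i \log Y_i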

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    from random import randint
    import numpy as np
    import matplotlib.pyplot as plt
    
    logs_path = 'log_simple_stats_softmax'
    batch_size = 100 # mini-batch size
    learning_rate = 0.5 # learning rate (declared but not used below; the Adam optimizer is built with 0.005)
    training_epochs = 10 # number of training epochs
    
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
    
    # model: Y = softmax(XX*W + b); X holds the input images, Y_ the one-hot labels
    X = tf.placeholder(tf.float32,[None,28,28,1],name="X")
    Y_ = tf.placeholder(tf.float32,[None,10],name="Y")
    W = tf.Variable(tf.zeros([784,10]),name="W")
    
    # flatten X into a [batch, 784] matrix
    XX = tf.reshape(X,[-1,784])
    b = tf.Variable(tf.zeros([10]),name="b")
    
    # linear evidence: XX*W + b
    evidence = tf.matmul(XX,W)+b
    # softmax turns the evidence into class probabilities
    Y = tf.nn.softmax(evidence,name="output")
    # cross-entropy loss; reduce_mean averages over 100*10 = 1000 entries per batch,
    # so the factor 1000.0 recovers the per-batch sum
    cross_entropy = -tf.reduce_mean(Y_ * tf.log(Y))*1000.0
    # optimizer: minimize the loss with Adam
    train_step = tf.train.AdamOptimizer(0.005).minimize(cross_entropy)
    # accuracy: fraction of samples whose predicted class matches the label
    correct_prediction = tf.equal(tf.argmax(Y,1),tf.argmax(Y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    # TensorBoard summaries
    tf.summary.scalar("cost",cross_entropy)
    tf.summary.scalar("accuracy",accuracy)
    summary_op = tf.summary.merge_all()
    
    with tf.Session() as sess:
        # initialize all variables
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(logs_path,graph=tf.get_default_graph())
        for epoch in range(training_epochs):
            # train in mini-batches
            batch_count = int(mnist.train.num_examples/batch_size)
            for i in range(batch_count):
                batch_x,batch_y = mnist.train.next_batch(batch_size)
                # feed the flattened tensor XX directly: next_batch already returns [batch_size, 784] images
                sess.run(train_step,feed_dict={XX:batch_x,Y_:batch_y})
            print("Epoch: ",epoch)
        # evaluate accuracy on the test set
        print("Accuracy: ",accuracy.eval(feed_dict={X:np.reshape(mnist.test.images,[-1,28,28,1]),Y_:mnist.test.labels}))
        print("done")
        # prediction on a random test image
        num = randint(0,mnist.test.images.shape[0]-1) # random index (randint is inclusive at both ends)
        test_img = np.reshape(mnist.test.images[num],[28,28,1])
        test_label = mnist.test.labels[num]
        # predicted label
        classification = sess.run(tf.argmax(Y,1),feed_dict={X:[test_img]})
        plt.imshow(np.reshape(test_img,[28,28]))
        plt.show()
        print("predict_label: ",classification[0])
        print("true_label: ",np.argmax(test_label))
    
    Extracting MNIST_data/train-images-idx3-ubyte.gz
    Extracting MNIST_data/train-labels-idx1-ubyte.gz
    Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
    Epoch:  0
    Epoch:  1
    Epoch:  2
    Epoch:  3
    Epoch:  4
    Epoch:  5
    Epoch:  6
    Epoch:  7
    Epoch:  8
    Epoch:  9
    Accuracy:  0.9218
    done
    predict_label:  5
    true_label:  5
    
    [output_1_1.png: the randomly selected test digit rendered with plt.imshow]
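    A side note: the script creates summary_op and a tf.summary.FileWriter, but the training loop never evaluates or writes those summaries, so the TensorBoard log only contains the graph. A minimal sketch of how the inner batch loop could log the scalars as well, reusing the names defined above (the per-step logging frequency here is only an assumption):

    # run the merged summaries together with the training step
    _, summary = sess.run([train_step, summary_op],
                          feed_dict={XX: batch_x, Y_: batch_y})
    # tag each point with a global step so TensorBoard plots cost/accuracy over time
    writer.add_summary(summary, epoch * batch_count + i)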
