美文网首页
卷积神经网络3完全解析

卷积神经网络3完全解析

作者: db13cf62b4e3 | 来源: 发表于 2018-10-28 21:01 | 被阅读 0 次

    #coding:utf-8

    import time

    import tensorflow as tf

    from tensorflow.examples.tutorials.mnist import input_data

    import mnist_lenet5_forward

    import mnist_lenet5_backward

    import numpy as np

    # Number of seconds test() sleeps between successive checkpoint evaluations.
    TEST_INTERVAL_SECS = 5

    def test(mnist):
        """Repeatedly evaluate the newest LeNet-5 checkpoint on the MNIST test set.

        The evaluation graph is built once. Then, in an endless loop, a fresh
        session restores the latest checkpoint recorded under
        ``mnist_lenet5_backward.MODEL_SAVE_PATH``, using the exponential
        moving-average (shadow) values of the variables. It then prints the
        test accuracy and sleeps ``TEST_INTERVAL_SECS`` seconds. The function
        returns as soon as no checkpoint is found.

        Args:
            mnist: dataset object from ``input_data.read_data_sets``. Its
                ``test.images``, ``test.labels`` and ``test.num_examples``
                attributes are read. Labels are compared via ``argmax``, so
                they are expected to be one-hot encoded.
        """
        # One shape definition shared by the placeholder and the reshape below,
        # so the two cannot drift apart.
        input_shape = [
            mnist.test.num_examples,
            mnist_lenet5_forward.IMAGE_SIZE,
            mnist_lenet5_forward.IMAGE_SIZE,
            mnist_lenet5_forward.NUM_CHANNELS,
        ]
        # The test images never change between evaluation passes, so reshape
        # them into the 4-D layout the placeholder expects exactly once.
        reshaped_x = np.reshape(mnist.test.images, input_shape)

        with tf.Graph().as_default():
            x = tf.placeholder(tf.float32, input_shape)
            y_ = tf.placeholder(tf.float32, [None, mnist_lenet5_forward.OUTPUT_NODE])
            # NOTE(review): the False/None arguments presumably select inference
            # mode and disable regularization -- confirm in mnist_lenet5_forward.
            y = mnist_lenet5_forward.forward(x, False, None)

            # Restore the moving-average (shadow) values instead of the raw weights.
            ema = tf.train.ExponentialMovingAverage(mnist_lenet5_backward.MOVING_AVERAGE_DECAY)
            saver = tf.train.Saver(ema.variables_to_restore())

            # Per example: does the predicted class (y) match the true class (y_)?
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            # Mean of the 0/1 matches is the prediction accuracy.
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

            while True:
                # Open a new session on each pass so the newest checkpoint is restored.
                with tf.Session() as sess:
                    ckpt = tf.train.get_checkpoint_state(mnist_lenet5_backward.MODEL_SAVE_PATH)
                    if not (ckpt and ckpt.model_checkpoint_path):
                        print('No checkpoint file found')
                        return
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # The step is taken as the text after the last '-' in the
                    # checkpoint file name (assumes a "<prefix>-<step>" naming).
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(
                        accuracy, feed_dict={x: reshaped_x, y_: mnist.test.labels})
                    print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
                time.sleep(TEST_INTERVAL_SECS)

    def main():
        """Load MNIST from ./data/ with one-hot labels and run the evaluation loop."""
        test(input_data.read_data_sets("./data/", one_hot=True))

    if __name__ == "__main__":
        # Start evaluation only when run as a script, not when imported.
        main()

    相关文章

      网友评论

          本文标题:卷积神经网络3完全解析

          本文链接:https://www.haomeiwen.com/subject/wlbmtqtx.html