美文网首页
tensorflow深度学习之验证数据集(三)

tensorflow深度学习之验证数据集(三)

作者: baihualinxin | 来源:发表于2018-04-25 15:00 被阅读0次

    from PIL import Image

    import matplotlib.pyplot as plt

    #

    def get_one_image(train, size=(64, 64)):
        '''Randomly pick one image file from ``train``, show it, and return it resized.

        Args:
            train: non-empty sequence of image file paths; one is chosen
                uniformly at random with ``np.random.randint``.
            size: ``(width, height)`` to resize the picked image to. Defaults
                to ``(64, 64)``, the input size the evaluation graph reshapes to.

        Returns:
            ndarray holding the pixels of the resized image.

        Raises:
            ValueError: if ``train`` is empty.
        '''
        n = len(train)
        if n == 0:
            # np.random.randint(0, 0) would also raise ValueError, but with an
            # unhelpful "high <= low" message; say what is actually wrong.
            raise ValueError('train must contain at least one image path')

        img_dir = train[np.random.randint(0, n)]

        # Image.open() is lazy and keeps the file handle open until the image
        # is closed; the context manager guarantees it is released.
        with Image.open(img_dir) as img:
            plt.imshow(img)
            # resize() loads the pixel data and returns a new, independent
            # image, so it stays valid after the source file is closed.
            resized = img.resize(tuple(size))

        return np.array(resized)

    def evaluate_one_image(train_dir='/Users/Desktop/cd/cd/train/',
                           logs_train_dir='/Users/Desktop/cd/cd/logs'):
        '''Classify one random image with the latest saved checkpoint.

        Picks a random image from ``train_dir`` (via ``get_one_image``), builds
        the inference graph, restores the newest checkpoint found in
        ``logs_train_dir`` and prints the predicted label: class index 0 is
        reported as "car", any other index as "not_car".

        Args:
            train_dir: directory containing the images to evaluate.
                Change the default to your own directory.
            logs_train_dir: directory containing the training checkpoints.
                Change the default to your own directory.

        Returns:
            The softmax output from ``sess.run`` (indexed as
            ``prediction[0, class]``), or ``None`` if no checkpoint was found.
        '''
        train, _ = input_data.get_files(train_dir)
        image_array = get_one_image(train)

        with tf.Graph().as_default():
            BATCH_SIZE = 1
            N_CLASSES = 2

            # Build the graph on a placeholder and feed the pixels at run time.
            # Previously the image was embedded as a constant and this
            # placeholder was fed but never connected to the graph.
            x = tf.placeholder(tf.float32, shape=[64, 64, 3])
            image = tf.image.per_image_standardization(x)
            image = tf.reshape(image, [1, 64, 64, 3])

            logit = model.inference(image, BATCH_SIZE, N_CLASSES)
            logit = tf.nn.softmax(logit)

            saver = tf.train.Saver()

            with tf.Session() as sess:
                print("Reading checkpoints...")
                ckpt = tf.train.get_checkpoint_state(logs_train_dir)
                if not (ckpt and ckpt.model_checkpoint_path):
                    # Without restored weights the variables are uninitialized
                    # and sess.run() would fail; stop here instead.
                    print('No checkpoint file found')
                    return None

                # The global step is the suffix after the last '-' in the
                # checkpoint file name (e.g. "model.ckpt-1000").
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Loading success, global_step is %s' % global_step)

                prediction = sess.run(
                    logit, feed_dict={x: image_array.astype(np.float32)})
                max_index = np.argmax(prediction)

                # Index with [0, k] to get a Python scalar; formatting a
                # 1-element array with %f relies on deprecated conversion.
                if max_index == 0:
                    print('This is a car with possibility %.6f' % prediction[0, 0])
                else:
                    print('This is a not_car with possibility %.6f' % prediction[0, 1])

                return prediction

    相关文章

      网友评论

          本文标题:tensorflow深度学习之验证数据集(三)

          本文链接:https://www.haomeiwen.com/subject/eccylftx.html