美文网首页
TensorFlow学习7:输入图片,预测结果

TensorFlow学习7:输入图片,预测结果

作者: 崔业康 | 来源:发表于2018-06-15 14:12 被阅读0次

    代码处理过程

    1,模型的要求是黑底白字,但输入的图是白底黑字,所以需要对每个像素点的值改为255减去原值以得到互补的反色
    2,对图片做二值化处理
    3,把图片形状拉成1行784列,并把值变成浮点型(要求像素点是0-1之间的浮点数)
    4,计算求得输出y,y的最大值所对应的列表索引号就是预测结果

    示例代码

    #coding:utf-8
    #将符合神经网络输入要求的图片喂给复现的神经网络模型,输出预测值
    def restore_model(testPicArr):
        #创建一个默认图,在该图中执行以下操作
        with tf.Graph().as_default() as tg:
            x=tf.placeholder(tf.float32,[None,mnist_forword.INPUT_NONE])
            y=mnist_forword.mnist_forword(x,None)
            #得到概率最大的预测值
            preValue=tf.argmax(y,1)
    
            #实现滑动平均模型,参数MOVING_AVERAGE_DECAY用于控制模型更新的速度
            #训练过程中会对每一个变量维护一个影子变量,这个影子变量的初始值
            #就是相应变量的初始值,每次变量更新时,影子变量就会随之更新
            variable_averages=tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
            variable_to_restore=variable_averages.variable_to_restore()
            saver=tf.train.Saver(variable_to_restore)
    
            with tf.session() as sess:
                #通过checkpoint文件定位到最新保存的模型
                ckpt=tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess,ckpt.model_checkpoint_path)
    
                    preValue=sess.run(preValue,feed_dict={x:testPicArr})
                    return preValue
                else:
                    print("No checkpoint file found")
                    return -1
    
    #预处理函数,包括resize,转变灰度图,二值化操作
    def pre_pic(picName):
        img=Image.open(picName)
        prIm=img.resize((28,28),Image.ANTIALIAS)
        im_arr=np.array(reIm.convert('L'))
        #设定合理的阙值
        threshold=50
        for i in range(28):
            for j in range(28):
                im_arr[i][j]=255-im_arr[i][j]
                if(im_arr[i][j]<threshold):
                    im_arr[i][j]=0
                else:
                    im_arr[i][j]=255
        nm_arr=im_arr.reshape([1,784])
        nm_arr=nm_arr.astype(np.float32)
        img_ready=np.multiply(nm_arr,1.0/255.0)
    
        return img_ready
    
    

    参考:人工智能实践:Tensorflow笔记

    相关文章

      网友评论

          本文标题:TensorFlow学习7:输入图片,预测结果

          本文链接:https://www.haomeiwen.com/subject/jbqmeftx.html