在mnist数据集上搭建BP神经网络,完成在测试集上的训练,详情参看:基于TensorFlow的mnist数据集BP网络搭建
那么我们思考:1.能不能对实际图片进行预测? 2.能不能用自定义的图片数据集进行预测?
首先看问题1,为了方便训练在mnist_backward.py中加入断点续训,这样在恢复训练后能继续上次的训练轮数,不必再重新开始:
断点续训其中,tf.train.get_checkpoint_state(checkpoint_dir,latestfile=None)表示如果文件夹包含有效断点文件则返回该文件,saver.restore(sess,ckpt.model_checkpoint_path)是恢复当前会话,将ckpt中的最新的w,b付给当前会话中。
在完成断点续训后,我们继续解决问题1,对实际图片进行预测,总共分2部:1)输入实际图片,对图片进行预处理使其符合NN要求;2)复现NN喂入图片。
先看步骤1),我们对图片进行简单的预处理,灰度化和二值化,使其大小为28*28,像素值在0-1之间黑底白字,大小是1行784列的数组。
预处理步骤2)复现神经网络,
复现NN还是,首先创建默认图,输入x占位,调用forward输出得到概率最大的索引值即预测y,实现滑动平均,传入更新速度moving_average_decay,并把每次更新得到的w,b和影子值保存。开启会话,根据chekpoint文件找到最新模型,喂入图片并计算返回预测值。
这两个步骤实现后我们只需在应用模块中调用就可:
应用在训练模型达到一定准确度后,运行application,就可以完成对实际图片的预测:
在下一篇,继续分析问题2,对自定义的图片数据集进行训练。
新手学习,欢迎指教!!
网友评论