一:简单实现神经网络步骤:
1、定义算法公式,神经网络forward时的计算
2、定义loss,选定优化器,并指定优化器优化loss
3、训练。对数据进行迭代的训练,得到w,b等值
4、在训练集上或验证集上验证,并对结果进行评测
二:代码
深度学习时建立在庞大数据上的一门学科。所以代码的第一步就是加载数据。
import tensorflowas tf #导入tensorflow
#导入数据
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("mnist_data/",one_hot=True) # 从input_data中导入one_hot 编码的数据,什么是one_hot编码?其实我也不知道。度娘说:直观来说就是有多少个状态就有多少比特,而且只有一个比特为1,其他全为0的一种码制 。
print(mnist.train.images.shape,mnist.train.labels.shape)
print(mnist.test.images.shape,mnist.test.labels.shape)
print(mnist.validation.images.shape,mnist.validation.labels.shape)#运行工程你就会在你的mnist_data 目录下看到对应的数据
接下来就是网络真正的第一步了,定义算法,算法公式怎么定义呢?我的理解是以一个线性方程为基础:y=wx+b。我们输入的数据为x,y为已经知道的结果,我们知道了y和x后去求w和b。通过很多的x求出最合适的w和b,然后在新的x上去验证我们求的w和b的准确率。
session=tf.InteractiveSession()#首先创建一个session
#获取输入的数据x。使用placeholder占用一个位置,后面动态输入,第一个参数为类型为浮点32位,第二个参数为矩阵,N行,784列的矩阵,784为28*28的图片数据
x=tf.placeholder(tf.float32,[None,784])
w=tf.Variable(tf.zeros([784,10]))#Variable为定义变量,zeros表示将w初始化为0,数据为784行,10列的矩阵,10列代表0-9个分类
b=tf.Variable(tf.zeros([10]))#Variable储存的参数可以持久化储存于显存中用于不断的迭代
y=tf.nn.softmax(tf.matmul(x,w)+b)#公式:y=wx+b 。以下为softmax公式。可以理解为将输入值x以e为底求指数的和作为分子,然后所有的作为分母得出一个分数,也可以看出是一个概率值。在0-1之间。它的作用就是将输入的x值分别与每一个类别的w和b运算后,得出对应的每一类的概率,最大概率的就为最后的输出结果。
创建损失函数:
y_=tf.placeholder(tf.float32,[None,10])#真实的标签
cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))#损失函数。通过对y_和log(y)的乘积求和,最后求平均数。那么这个函数怎么就可以判断出预测与真实的差距呢?
创建优化器:
train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#创建优化器,随机梯度下降(SGD),优化损失函数,更新参数
tf.global_variables_initializer().run()#初始化模型的参数
开始迭代训练:
for i in range(1000):
batch_xs,batch_ys =mnist.train.next_batch(100)#每次取100个样本
train_step.run({x:batch_xs,y_:batch_ys})#分别赋给x,y,并运行。运行时,tensorflow会自动更新参数并优化参数。
接下来就是验证准确率:
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) #tf.argmax(y,1)得出预测的结果中概率最大的一个。tf.argmax(y_,1)真实的类别。equal 判断是否相等。
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))
到此,数字识别就完成了。本人是菜鸟中的菜鸟,理解有问题的地方,请各位大神不吝赐教。
网友评论