美文网首页TensorFlow
TensorFlow基本步骤

TensorFlow基本步骤

作者: MRJOHN_CUIT | 来源:发表于2019-04-09 21:15 被阅读0次

    1.载入TensorFlow库

    import tensorflow as tf  //as tf 意思是重命名为tf
    

    2.创建新的InteractiveSession,使新创建的session注册为默认的session,之后的运算默认跑在这个session里面,不同session之间的数据和运算都是相互独立的

    sess = tf.InteractiveSession()
    

    3.创建一个Placeholder用户输入数据

    x  = tf.placeholder(tf.float32,[None,784])
    

    第一个参数是数据类型,第二个参数代表tensor的数据尺寸。其中的None代表输入的条数不限,784代表每条输入是一个784维的向量。
    4.给模型中的weights和biases创建Variable对象

    w = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    

    参数是什么意思????目前不懂
    5.实现算法
    由公式y=softmax(wx+b)得到代码

    y = tf.nn.softmax(tf.matmul(x,w)+b))
    

    其中soft是一个函数,包含在tf.nn下,tf.nn含有大量的神经网络的组件。tf.matmul是TensorFlow中的矩阵乘法函数
    6.描述模型对问题的分类精度——常用cross_entropy

    y_ = tf.placeholder(tf.float32,[None,10])  //用于输入真实的label
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
    

    7.定义优化算法进行训练

    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    

    TensorFlow会自动在后台进行运算。GradientDescentOptimizer是梯度下降算法,0.5是学习的速率,最后一个参数是优化目标。结果train_step为训练操作
    8.运行全局参数初始化器,并直接运行他的run方法

    tf.global_variables_initializer().run()
    

    9.迭代地执行训练操作train_step。

    for i in range(1000):
      batch_x, batch_y = mnist.train.next_batch(100)
      train_step.run({x:batch_xs, y_:batch_ys})
    

    其中参数100表示每次随机从训练集中抽取100条样本构成一个mini-batch,并feed给placeholder
    10.完成训练后对模型的准确率进行验证

    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    

    tf.argmax是一个从tensor中寻找最大值的序号。它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签。tf.argmax(y,1)是求各个预测的数字中概率最大的哪一个。得到的结果是一个布尔值数组
    11.将上面的到的结果转换为正确率

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval({x:mnist.test.images, y_:mnist.test.labels}))
    

    tf.cast(correct_prediction, tf.float32)将布尔数组转化为float32格式

    相关文章

      网友评论

        本文标题:TensorFlow基本步骤

        本文链接:https://www.haomeiwen.com/subject/gfcmiqtx.html