美文网首页
tensorflow classification 分类学习

tensorflow classification 分类学习

作者: 朱宏飞 | 来源:发表于2018-08-25 21:54 被阅读0次

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist=input_data.read_data_sets('MNIST_data',one_hot=True)

#创建一个神经层

import numpy as np

def add_layer(inputs,in_size,out_size,activation_function=None):

    #这里激励函数默认为0,则可认为激励函数为线性

    Weights=tf.Variable(tf.random_normal([in_size,out_size]))

    biases=tf.Variable(tf.zeros([1,out_size]))+0.1

    Wx_plus_b=tf.matmul(inputs,Weights)+biases

    #还未被激活的值被赋予在Wx_plus_b中,下一步是去激活他

    if activation_function is None:

        outputs=Wx_plus_b

        #说明outputs是一个线性方程,不需要去activation_function.

        #使用activation_funtion将一个线性的函数变换成一个非线性的函数

    else:

    outputs=activation_function(Wx_plus_b)

    return outputs

def compute_accuracy(v_xs,v_ys):

    global prediction

    y_pre=sess.run(prediction,feed_dict={xs:v_xs})

    correct_prediction=tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))#tf.argmax()对矩阵按行或按列计算最大值所在位置,0表示按列,1表示按行

    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    #tf.cast(),是类型转换函数,转换成float32类型

    result=sess.run(accuracy)

    return result

#defind placeholder for inputs to network

    xs=tf.placeholder(tf.float32,[None,784])#28*28

    ys=tf.placeholder(tf.float32,[None,10])

#add output layer

    prediction=add_layer(xs,784,10,activation_function=tf.nn.softmax)

#softmax 一般用来做分类问题

#the error between prediction and real data

cross_entropy=tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction)),reducition_indices=[1]))

#交叉熵

train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

init=tf.initialize_all_variables()

with tf.Session() as sess:

    sess.run(init)

for i in range(1000):

    batch_xs,batch_ys=mnist.train.next_batch(100)#随机选取100组数据进行训练

    sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})

    if i%50==0:

        print(compute_accuracy(mnist.test.images,mnist.test.labels))

相关文章

网友评论

      本文标题:tensorflow classification 分类学习

      本文链接:https://www.haomeiwen.com/subject/llqniftx.html