美文网首页
简单线性识别图片

简单线性识别图片

作者：Do_More | 发表于 2017-10-17 09:15 | 被阅读 0 次
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import time
import data_helpers

# Train a single-layer linear (softmax) image classifier with plain gradient
# descent, report periodic mini-batch accuracy, evaluate once on the test
# split, and print total wall-clock time.
#
# NOTE: every TensorFlow op below is created in the same order as in the
# original script, so auto-generated graph names (e.g. "Variable",
# "Variable_1") are unchanged.

# Start the clock first so the reported time covers data loading, graph
# construction, training and evaluation.
beginTime = time.time()

# Hyperparameters.
batch_size = 100        # examples drawn per gradient step
learning_rate = 0.005   # step size for GradientDescentOptimizer
max_steps = 1000        # number of training iterations

# data_helpers is a project-local module. This script reads only the keys
# 'images_train', 'labels_train', 'images_test' and 'labels_test'; the values
# are indexed like NumPy arrays below.
data_sets = data_helpers.load_data()

# Graph inputs: rows of 3072 float features and one int64 class id per row.
# NOTE(review): 3072 presumably equals 32*32*3 (CIFAR-10-style RGB images);
# confirm against data_helpers.
images_placeholder = tf.placeholder(tf.float32, shape=[None, 3072])
labels_placeholder = tf.placeholder(tf.int64, shape=[None])

# Model parameters, zero-initialised: 3072 inputs -> 10 class scores.
weights = tf.Variable(tf.zeros([3072, 10]))
biases = tf.Variable(tf.zeros([10]))

# Unnormalised class scores, one row of 10 per input image.
logits = tf.matmul(images_placeholder, weights) + biases

# Batch-mean softmax cross-entropy. The "sparse" variant takes integer class
# ids directly instead of one-hot vectors.
per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits,
    labels=labels_placeholder,
)
loss = tf.reduce_mean(per_example_loss)

# One call to sess.run(train_step) applies one gradient-descent update.
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Accuracy = fraction of rows whose arg-max score matches the label.
predicted_class = tf.argmax(logits, 1)
correct_prediction = tf.equal(predicted_class, labels_placeholder)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # These lookups do not change inside the loop, so fetch them once.
    train_images = data_sets['images_train']
    train_labels = data_sets['labels_train']
    num_train = train_images.shape[0]

    for i in range(max_steps):
        # Random mini-batch. np.random.choice samples WITH replacement by
        # default, so a batch can contain the same example more than once.
        indices = np.random.choice(num_train, batch_size)
        images_batch = train_images[indices]
        labels_batch = train_labels[indices]
        batch_feed = {
            images_placeholder: images_batch,
            labels_placeholder: labels_batch,
        }

        # Every 100 steps, report accuracy on the current batch. It is
        # measured before this step's update is applied.
        if i % 100 == 0:
            train_accuracy = sess.run(accuracy, feed_dict=batch_feed)
            print('Step {:5d}: training accuracy {:g}'.format(i, train_accuracy))

        sess.run(train_step, feed_dict=batch_feed)

    # Final evaluation on the whole test split in a single run.
    test_feed = {
        images_placeholder: data_sets['images_test'],
        labels_placeholder: data_sets['labels_test'],
    }
    test_accuracy = sess.run(accuracy, feed_dict=test_feed)
    print('Test accuracy {:g}'.format(test_accuracy))

endTime = time.time()
elapsed = endTime - beginTime
print("Total time: {:5.2f}s".format(elapsed))

相关文章

网友评论

      本文标题:简单线性识别图片

      本文链接:https://www.haomeiwen.com/subject/eumruxtx.html