1. Environment Setup
To train on the MNIST dataset you need Python and TensorFlow installed. The simplest route is to set the environment up locally; there are plenty of tutorials online, so I won't repeat them here. If you want to use the same setup I tested with, see: Using Anaconda with PyCharm.
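As a rough guide, the script below depends on TF 1.x-only APIs (`tf.layers`, `tensorflow.examples.tutorials`, frozen graphs), so an environment sketch might look like the fragment below. The version pin is my assumption (the "MnistTf160" name suggests TF 1.6), not something stated in the original:

```text
# requirements.txt (sketch; any TF 1.x >= 1.6 should work, TF 2.x will not)
tensorflow>=1.6,<2.0
```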
2. Training Code
My training results are not perfect yet, so I'll post a working version first; it still needs further tuning.
# coding=utf-8
# MNIST classifier (two fully connected hidden layers + softmax): train and save the model
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.framework import graph_util
print('tensorflow:{0}'.format(tf.__version__))
# Put the MNIST training data under the Mnist_data/ directory
mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)
x = tf.placeholder(tf.float32, [None, 784], name='x_input')  # input node name: x_input
y_ = tf.placeholder(tf.float32, [None, 10], name='y_input')
dense1 = tf.layers.dense(inputs=x,
                         units=512,
                         activation=tf.nn.relu,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)
dense2 = tf.layers.dense(inputs=dense1,
                         units=512,
                         activation=tf.nn.relu,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)
logits = tf.layers.dense(inputs=dense2,
                         units=10,
                         activation=None,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss, name='logit')
y = tf.nn.softmax(logits, name='final_result')
# Define the loss function and the optimizer
with tf.name_scope('loss'):
    # Cross-entropy summed over the batch: -sum(y_ * log(y)).
    # Clip y away from 0 so tf.log never produces -inf.
    loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
with tf.name_scope('train_step'):
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# Initialize variables
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)
# Train
print("Start training...")
for step in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(50)
    train_step.run({x: batch_xs, y_: batch_ys})
print("Training finished...")
# Evaluate accuracy on the test set
pre_num = tf.argmax(y, 1, output_type=tf.int32, name="output")  # output node name: output
correct_prediction = tf.equal(pre_num, tf.argmax(y_, 1, output_type=tf.int32))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
a = accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})
print('Test accuracy: {0}'.format(a))
savePbPath = '/Users/tal/Desktop/tf/pb/mnist_test.pb'
# Save the trained model
# output_node_names selects the output node(s); ['output'] matches pre_num = tf.argmax(y, 1, name="output")
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
with tf.gfile.GFile(savePbPath, 'wb') as fd:  # 'wb': write the serialized graph in binary mode
    fd.write(output_graph_def.SerializeToString())
sess.close()
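For intuition, the three numerical steps the script leans on (softmax over the logits, the cross-entropy loss, and argmax accuracy) can be sketched in plain Python with no TensorFlow required. The numbers below are made up for illustration; only the formulas mirror the script:

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating, for numerical stability
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(probs, one_hot):
    # -sum(y_ * log(y)); clip probabilities away from 0 to avoid log(0),
    # just like the clip in the training script
    eps = 1e-10
    return -sum(t * math.log(max(p, eps)) for p, t in zip(probs, one_hot))

def accuracy(pred_rows, label_rows):
    # Fraction of rows where argmax(prediction) == argmax(one-hot label)
    hits = sum(1 for p, l in zip(pred_rows, label_rows)
               if p.index(max(p)) == l.index(max(l)))
    return hits / len(pred_rows)

logits = [2.0, 1.0, 0.1]          # made-up logits for one example
probs = softmax(logits)           # probabilities, sum to 1.0
label = [1, 0, 0]                 # one-hot label (what one_hot=True produces)
loss = cross_entropy(probs, label)
acc = accuracy([probs], [label])  # argmax of probs is 0, matching the label
```

This is also why `one_hot=True` is passed to `read_data_sets`: the label for digit `k` becomes a length-10 vector with a 1 at index `k`, so `y_ * tf.log(y)` picks out exactly the log-probability of the true class.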
3. References
1. A simple TensorFlow linear model (softmax regression) for recognizing MNIST handwritten digits:
http://www.enpeizhao.com/?p=179
2. Reference code 1: https://github.com/ChouBaoDxs/mnist_testdemo
3. Reference code 2: https://github.com/ChaoflyLi/MnistAndroid
4. Reference code 3: https://github.com/ChaoflyLi/MnistToAndroid
5. Reference blog: https://blog.csdn.net/chaofeili/article/details/89374324
6. My own cleaned-up code: MnistTf160