美文网首页
Mnist数据集的训练

Mnist数据集的训练

作者: 放羊娃华振 | 来源:发表于2020-05-17 13:30 被阅读0次

    一、环境的安装

    如果要训练mnist数据集,需要安装TensorFlow和Python环境,简单的就是在本地把环境都安装好,这个网上有很多的教程,我就不赘述了。如果希望使用我测试使用的方式可以参考:Anacanda在Pycharm的使用

    二、训练代码

    目前我训练的结果还不是百分百的完美,所以我就先贴出来一个吧,后续还需要不断优化。

    # coding=utf-8
    # 单隐层SoftMax Regression分类器:训练和保存模型模块
    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    
    print('tensortflow:{0}'.format(tf.__version__))
    
    # Mnist_dat目录下面需要放mnist训练数据
    mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)
    
    
    x = tf.placeholder(tf.float32, [None, 784], name='x_input')  # 输入节点名:x_input
    y_ = tf.placeholder(tf.float32, [None, 10], name='y_input')
    
    
    dense1 = tf.layers.dense(inputs=x,
                             units=512,
                             activation=tf.nn.relu,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                             kernel_regularizer=tf.nn.l2_loss)
    dense2 = tf.layers.dense(inputs=dense1,
                             units=512,
                             activation=tf.nn.relu,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                             kernel_regularizer=tf.nn.l2_loss)
    logits = tf.layers.dense(inputs=dense2,
                             units=10,
                             activation=None,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                             kernel_regularizer=tf.nn.l2_loss, name='logit')
    y = tf.nn.softmax(logits, name='final_result')
    
    # 定义损失函数和优化方法
    with tf.name_scope('loss'):
        # loss = -tf.reduce_sum(y_ * tf.log(y))
        loss = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)))
    with tf.name_scope('train_step'):
        train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
        print(train_step)
    # 初始化
    sess = tf.InteractiveSession()
    init = tf.global_variables_initializer()
    sess.run(init)
    
    # 训练
    print("开始训练。。。")
    for step in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(50)
        train_step.run({x: batch_xs, y_: batch_ys})
    print("训练完成。。。")
    # 测试模型准确率
    pre_num = tf.argmax(y, 1, output_type='int32', name="output")  # 输出节点名:output
    correct_prediction = tf.equal(pre_num, tf.argmax(y_, 1, output_type='int32'))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    a = accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})
    print('测试正确率:{0}'.format(a))
    
    savePbPath = '/Users/tal/Desktop/tf/pb/mnist_test.pb'
    # 保存训练好的模型
    # 形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
    output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
    # with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:  # ’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
    #     f.write(output_graph_def.SerializeToString())
    with tf.gfile.GFile(savePbPath, 'wb') as fd:
        fd.write(output_graph_def.SerializeToString())
    sess.close()
    

    三、参考资料

    1、Tensorflow简单线性模型(softmax regression)识别Mnist手写数字:
    http://www.enpeizhao.com/?p=179
    2、参考代码1:https://github.com/ChouBaoDxs/mnist_testdemo
    3、参考代码2:https://github.com/ChaoflyLi/MnistAndroid
    4、参考代码3:https://github.com/ChaoflyLi/MnistToAndroid
    5、参考博客:https://blog.csdn.net/chaofeili/article/details/89374324
    6、我自己整理代码:MnistTf160

    相关文章

      网友评论

          本文标题:Mnist数据集的训练

          本文链接:https://www.haomeiwen.com/subject/vozjohtx.html