美文网首页
tensorflow笔记:第五讲 全连接网络基础(+断点续训)

tensorflow笔记:第五讲 全连接网络基础(+断点续训)

作者: 九除以三还是三哦 | 来源:发表于2019-08-08 21:50 被阅读0次
  • 下载mnist数据集
    1. 首先四个文件下载打包好后直接拖入虚拟机中,下载地址在这里
文件 内容
train-images-idx3-ubyte.gz 训练集图片 - 55000 张 训练图片, 5000 张 验证图片
train-labels-idx1-ubyte.gz 训练集图片对应的数字标签
t10k-images-idx3-ubyte.gz 测试集图片 - 10000 张 图片
t10k-labels-idx1-ubyte.gz 测试集图片对应的数字标签

底层的源码将会执行下载、解压、重构图片和标签数据来组成以下的数据集对象:

数据集 目的
data_sets.train 55000 组 图片和标签, 用于训练。
data_sets.validation 5000 组 图片和标签, 用于迭代验证训练的准确性。
data_sets.test 10000 组 图片和标签, 用于最终测试训练的准确性。
  1. 然后需要写一个input_data.py文件
    代码在这里:https://testerhome.com/topics/18906
  2. 最后运行程序就好啦
  • 测试过的代码

mnist_forward.py

#coding:utf-8
#版本信息:ubuntu18.04  python3.6.8  tensorflow 1.14.0
#作者:九除以三还是三哦  如有错误,欢迎评论指正!!
import tensorflow as tf

INPUT_NODE=784   #网络输入节点为784个(代表每张输入图片的像素个数)
OUTPUT_NODE=10   #输出节点为10个(表示输出为数字0-9的十分类
LAYER1_NODE=500  #隐藏层节点500个

def get_weight(shape,regularizer):
    #参数满足截断正态分布,并使用正则化
    w=tf.Variable(tf.truncated_normal(shape,stddev=0.1))
    #将每个参数的正则化损失加到总损失中
    if regularizer !=None:tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape):
    b=tf.Variable(tf.zeros(shape))
    return b

def forward(x,regularizer):
    #由输入层到隐藏层的参数w1形状为[784,500]
    w1=get_weight([INPUT_NODE,LAYER1_NODE],regularizer)
    #由输入层到隐藏的偏置b1形状为长度500的一维数组
    b1=get_bias([LAYER1_NODE])
    #前向传播结构第一层为输入 x与参数 w1矩阵相乘加上偏置 b1 ,再经过relu函数 ,得到隐藏层输出 y1。
    y1=tf.nn.relu(tf.matmul(x,w1)+b1)
    #由隐藏层到输出层的参数w2形状为[500,10]
    w2=get_weight([LAYER1_NODE,OUTPUT_NODE],regularizer)
    #由隐藏层到输出的偏置b2形状为长度10的一维数组
    b2=get_bias([OUTPUT_NODE])
    #前向传播结构第二层为隐藏输出 y1与参 数 w2 矩阵相乘加上偏置 矩阵相乘加上偏置 b2,得到输出 y。
    #由于输出 。由于输出 y要经过softmax oftmax 函数,使其符合概率分布,故输出y不经过 relu函数
    y=tf.matmul(y1,w2)+b2
    return y

mnist_backward.py

#coding:utf-8
#版本信息:ubuntu18.04  python3.6.8  tensorflow 1.14.0
#引入tensorflow、input_data、前向传播mnist_forward和os模块
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

BATCH_SIZE=200  #每轮喂入神经网络的图片数
LEARNING_RATE_BASE=0.1  #初始学习率
LEARNING_RATE_DECAY=0.99  #学习率衰减率
REGULARIZER=0.0001  #正则化系数
STEPS=50000   #训练轮数
MOVING_AVERAGE_DECAY=0.99   #滑动平均衰减率
MODEL_SAVE_PATH="./model/"  #模型保存路径
MODEL_NAME="mnist_model"   #模型保存名称
 

def backward(mnist):

    #用placeholder给训练数据x和标签y_占位
    x=tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
    y_=tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
    #调用mnist_forward文件中的前向传播过程forword()函数,并设置正则化,计算训练数据集上的预测结果y
    y=mnist_forward.forward(x,REGULARIZER)
    #当前计算轮数计数器赋值,设定为不可训练类型
    global_step=tf.Variable(0,trainable=False)

    #调用包含所有参数正则化损失的损失函数lossce=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
    cem=tf.reduce_mean(ce)
    loss=cem+tf.add_n(tf.get_collection('losses'))

    learning_rate=tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples/BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    #使用梯度衰减算法对模型优化,降低损失函数
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)  
   
 #定义参数的滑动平均   ema=tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
    ema_op=ema.apply(tf.trainable_variables())
    #实例化可还原滑动平均的saver 
    #在模型训练时引入滑动平均可以使模型在测试数据上表现的更加健壮
    with tf.control_dependencies([train_step,ema_op]):
        train_op=tf.no_op(name='train')

    saver=tf.train.Saver()

    with tf.Session() as sess:
       #所有参数初始化
        init_op=tf.global_variables_initializer()
        sess.run(init_op)

        #断点续训,加入ckpt操作
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        #每次喂入batch_size组(即200组)训练数据和对应标签,循环迭代steps轮
        for i in range(STEPS):
            xs,ys=mnist.train.next_batch(BATCH_SIZE)
            _,loss_value,step=sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})
            if i%1000==0:
                print("After %d training step(s),loss on training batch is %g."%(step,loss_value))
                saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)


def main():
    #读入mnist
    mnist=input_data.read_data_sets("./data/",one_hot=True)
     #反向传播
    backward(mnist)

if __name__=='__main__':
    main()

mnist_test.py

#coding:utf-8
#版本信息:ubuntu18.04  python3.6.8  tensorflow 1.14.0
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward
TEST_INTERVAL_SECS=5#程序5秒的循环间隔时间

def test(mnist):#利用tf.Graph()复现之前定义的计算图
    with tf.Graph().as_default() as g:
        #利用placeholder给训练数据x和标签y_占位
        x=tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
        y_=tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
        #调用mnist_forward文件中的前向传播过程forword()函数
        y=mnist_forward.forward(x,None)

#实例化具有滑动平均的saver对象,从而在会话被加载时模型中的所有参数被赋值为各自的滑动平均值,增强模型的稳定性        ema=tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        ema_restore=ema.variables_to_restore()
        savor=tf.train.Saver(ema_restore)

        #计算模型在测试集上的准确率
        correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
        accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

        while True:#加载指定路径下的ckpt
            with tf.Session() as sess:
                ckpt=tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                #若模型存在,则加载出模型到当前对话,在测试数据集上进行准确率验证,并打印出当前轮数下的准确率
                if ckpt and ckpt.model_checkpoint_path:
                    savor.restore(sess,ckpt.model_checkpoint_path)
                    global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
                    print("After %s training step(s),test accuracy=%g"%(global_step,accuracy_score))
                #若模型不存在,则打印出模型不存在的提示,从而test()函数完成
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(TEST_INTERVAL_SECS)

def main():
    #加载指定路径下的测试数据集
    mnist=input_data.read_data_sets("./data/",one_hot=True)
    test(mnist)

if __name__=='__main__':
    main()
  • 运行结果

从终端显示的运行结果可以看出,随着训练轮数的增加,网络模型的损失函数值在不断降低,训练集上的精确度也在不断提高,具有良好的泛化能力。


数字识别.png

相关文章

网友评论

      本文标题:tensorflow笔记:第五讲 全连接网络基础(+断点续训)

      本文链接:https://www.haomeiwen.com/subject/qnifjctx.html