美文网首页我爱编程
TF笔记 - 全连接网络基础

TF笔记 - 全连接网络基础

作者: 闫_锋 | 来源:发表于2018-05-27 16:38 被阅读46次

√mnist 数据集:包含7万张黑底白字手写数字图片,其中 55000 张为训练集,5000 张为验证集,10000 张为测试集。每张图片大小为 28*28 像素,图片中纯黑色像素值为 0,纯白色像素值为1。数据集的标签是长度为10的一维数组,数组中每个元素索引号表示对应数字出现的概率。

每张图片变为长度784 一维数组,将该数组作为神经网络输入特征喂入神经网络。

√使用 input_data 模块中的 read_data_sets()函数加载 mnist 数据集:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(’./data/’,one_hot=True)

√返回mnist数据集中训练集train、验证集validation和测试集test 样本数
1返回训练集train样本数

print “train data size:”,mnist.train.mun_examples

2返回验证集validation样本数

print “validation data size:”,mnist.validation.mun_examples

3返回测试集test样本数

print “test data size:”,mnist.test.mun_examples

√使用 train.labels 函数返回mnist数据集标签

mnist.train.labels[0]

√使用 train.images 函数返回mnist数据集图片像素值

mnist.train.images[0]

√使用mnist.train.next_batch()函数将数据输入神经网络

xs, ys = mnist.train.next_batch(BATCH_SIZE)

√实现“Mnist 数据集手写数字识别”的常用函数:
1 tf.get_collection(“”)函数表示从 collection 集合中取出全部变量生成
一个列表。

2 tf.add_n( )函数表示将参数列表中对应元素相加。

3 tf.cast(x, dtype)函数表示将参数 x 转换为指定数据类型。

4 tf.argmax(x, axis)返回最大值所在索引号如: tf.argmax([1,0,0], 1)返回0

5 os.path.join("home", "name")

6 str.split()

7 with tf.Grapg().as_default as g: 其内定义的节点

8 tf.equal( )函数表示对比两个矩阵或者向量的元素。若对应元素相等,则返回 True;若对应元素不相等,则返回 False。

9 tf.reduce_mean(x,axis)函数表示求取矩阵或张量指定维度的平均值。若不指定第二个参数,则在所有元素中取平均值;若指定第二个参数为 0,则在第一维元素上取平均值,即每一列求平均值;若指定第二个参数为 1,则在第二维元素上取平均值,即每一行求平均值。

√神经网络模型的保存
在反向传播过程中,一般会间隔一定轮数保存一次神经网络模型,并产生三个文件(保存当前图结构的.meta文件、保存当前参数名的.index文件、保存当前参数的.data 文件),在Tensorflow中如下表示:

saver = tf.train.Saver()
with tf.Session() as sess:
  for i in range(STEPS):
    if i % 轮数 == 0:
      saver.save(sess, os.path.join(MODEL_SAVE_PATH,
MODEL_NAME), global_step=global_step)

其中,tf.train.Saver()用来实例化 saver 对象。上述代码表示,神经网络每循环规定的轮数,将神经网络模型中所有的参数等信息保存到指定的路径中,并在存放网络模型的文件夹名称中注明保存模型时的训练轮数。

√神经网络模型的加载
在测试网络效果时,需要将训练好的神经网络模型加载,在 Tensorflow 中这
样表示:
with tf.Session() as sess:
  ckpt = tf.train.get_checkpoint_state(存储路径)
  if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

在 with 结构中进行加载保存的神经网络模型,若ckpt 和保存的模型在指定路径中存在,则将保存的神经网络模型加载到当前会话中。

√加载模型中参数的滑动平均值
在保存模型时,若模型中采用滑动平均,则参数的滑动平均值会保存在相应文件中。通过实例化 saver 对象,实现参数滑动平均值的加载,在Tensorflow中如下表示:

ema = tf.train.ExponentialMovingAverage(滑动平均基数)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)

√神经网络模型准确率评估方法
在网络评估时,一般通过计算在一组数据上的识别准确率,评估神经网络的效果。在Tensorflow中这样表示:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

相关文章

网友评论

    本文标题:TF笔记 - 全连接网络基础

    本文链接:https://www.haomeiwen.com/subject/hnbbjftx.html