
BPTT Derivation and an LSTM-Based MNIST Handwritten Digit Recognition Experiment

Author: Rethinkpossible | Published 2018-08-10 10:17

BPTT (Back Propagation Through Time) Derivation

Since RNNs are widely used for sequence labeling problems, that task is taken here as the running example for explaining BPTT. The figure below shows the very common structure obtained by unrolling a typical RNN through time.

(Figure 1)

Figure 1 can then be made concrete as the following diagram:

(Figure 2)
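The original figures are not reproduced in this copy. As a stand-in, the forward pass of a vanilla RNN unrolled over t = 1..T and the resulting BPTT gradient with respect to the recurrent weights can be written (in generic notation that may not match the figures) as:

    h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h), \qquad
    \hat{y}_t = \mathrm{softmax}(W_{hy} h_t + b_y), \qquad
    L = \sum_{t=1}^{T} L_t

    \frac{\partial L}{\partial W_{hh}}
      = \sum_{t=1}^{T} \sum_{k=1}^{t}
        \frac{\partial L_t}{\partial h_t}
        \left( \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \right)
        \frac{\partial h_k}{\partial W_{hh}}

Here the last factor is the immediate partial derivative of h_k with respect to W_{hh} (with h_{k-1} treated as a constant), and the chained Jacobian product of the \partial h_j / \partial h_{j-1} terms is what carries the error signal backwards "through time", which is where BPTT gets its name.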

Notes on the LSTM-Based MNIST Handwritten Digit Recognition Experiment

The basic structure can then be arranged into the time-unrolled structure shown below:

Note that the figure above does not show different neurons in different blocks; it shows the same neuron of the same block at different time steps, and how information is passed from one time step to the next. The details of a block and its formulas are shown in the following figure:
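The figure with the block internals is likewise not reproduced here. For reference, the standard LSTM gate equations (in the notation of the Understanding-LSTMs post listed in the references, which may differ from the original figure) are:

    f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)            % forget gate
    i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)            % input gate
    \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)     % candidate cell state
    C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t         % cell state update
    o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)            % output gate
    h_t = o_t \odot \tanh(C_t)                              % hidden state / block output

This is essentially what rnn.BasicLSTMCell implements in the code below, with forget_bias=1.0 added to the forget-gate pre-activation.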

The code is as follows (TensorFlow 1.x; it relies on tf.contrib, which no longer exists in TensorFlow 2.x):

from __future__ import print_function

import numpy as np

import tensorflow as tf

from tensorflow.contrib import rnn

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

class mnistExp:

    def __init__(self, datapath, logpath):

       self.mnist = read_data_sets(datapath, one_hot=True)

       self.logpath = logpath

       self.learning_rate = 0.001

       self.iters_num = 100000

       self.batch_size = 128

       self.print_step = 10

       self.input_size = 28

       self.timesteps = 28

       self.hidden_size = 128

       self.class_num = 10

       self.epoch_num = 30

       self.test_len = 300

    def InitGblVar(self):

       with tf.name_scope('Inputs'):
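            # Each 28x28 image is fed as a 28-step sequence, one 28-pixel row per time step.
            # istate holds an initial LSTM state (cell + hidden, hence 2 * hidden_size); note that,
            # as written, LSTM() below never actually consumes it.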

           self.x = tf.placeholder("float", [None, self.timesteps, self.input_size], name='X')

           self.istate = tf.placeholder("float", [None, 2 * self.hidden_size], name='istate')

           self.y = tf.placeholder("float", [None, self.class_num], name='Y')

       with tf.name_scope('Weights'):

           self.weights = {

               'hidden': tf.Variable(tf.random_normal([self.input_size, self.hidden_size])),

               'out': tf.Variable(tf.random_normal([self.hidden_size, self.class_num]))

           }

           tf.summary.histogram('weights_hidden',self.weights['hidden'])

           tf.summary.histogram('weights_out',self.weights['out'])

       with tf.name_scope('Biases'):

           self.biases = {

               'hidden': tf.Variable(tf.random_normal([self.hidden_size])),

               'out': tf.Variable(tf.random_normal([self.class_num]))

           }

           tf.summary.histogram('bias_hidden',self.biases['hidden'])

           tf.summary.histogram('bias_out',self.biases['out'])

       self.pred = self.LSTM(self.x, self.istate, self.weights, self.biases)

       with tf.name_scope('Cost'):

           self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.pred, labels=self.y))

           tf.summary.scalar('Cost',self.cost)

       with tf.name_scope('Train'):

           self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)

       with tf.name_scope('Prediction'):

           self.correct_pred = tf.equal(tf.argmax(self.pred, 1), tf.argmax(self.y, 1))

           self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))

           tf.summary.scalar('Accuracy',self.accuracy)

       self.init = tf.global_variables_initializer()

    def LSTM(self, _batch_x, _istate, _weights, _biases):

       with tf.name_scope('lstm_block'):
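            # Convert the (batch, timesteps, input_size) tensor into a list of `timesteps` tensors of
            # shape (batch, hidden_size): transpose, flatten, project through the hidden weights, then
            # split along axis 0, which is the input format expected by rnn.static_rnn.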

           _batch_x = tf.transpose(_batch_x, [1, 0, 2])

           _batch_x = tf.reshape(_batch_x, [-1, self.input_size])

           _batch_x = tf.matmul(_batch_x, _weights['hidden']) + _biases['hidden']

           _batch_x = tf.split(_batch_x, self.timesteps, 0)

           lstm_block = rnn.BasicLSTMCell(self.hidden_size, forget_bias=1.0)

           outputs, states = rnn.static_rnn(lstm_block, _batch_x, dtype=tf.float32)
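            # Only the output of the final time step is used to classify the whole digit.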

       return tf.matmul(outputs[-1], _weights['out']) + _biases['out']

    def run(self):

       self.InitGblVar()

       with tf.Session() as sess:

           sess.run(self.init)

           merged = tf.summary.merge_all()

           writer = tf.summary.FileWriter(self.logpath, sess.graph)

           for i in range(self.epoch_num):
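                # Each epoch: train on mini-batches until roughly iters_num samples have been seen,
                # then evaluate on the first test_len test images.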

               step = 1

               while step * self.batch_size < self.iters_num:

                   batch_xs, batch_ys = self.mnist.train.next_batch(self.batch_size)

                   # tf.summary.image('batch_xs',batch_xs,max_outputs=10)

                   batch_xs = batch_xs.reshape((self.batch_size, self.timesteps, self.input_size))

                   sess.run(self.optimizer, feed_dict={self.x: batch_xs, self.y: batch_ys,

                                                  self.istate: np.zeros((self.batch_size, 2 * self.hidden_size))})

                   if step % self.print_step == 0:

                        acc, loss, summary = sess.run([self.accuracy, self.cost, merged],
                                                      feed_dict={self.x: batch_xs, self.y: batch_ys,
                                                                 self.istate: np.zeros((self.batch_size, 2 * self.hidden_size))})
                        # Report minibatch loss and accuracy every print_step batches
                        print(" [*] Iter " + str(step * self.batch_size) + ", minibatch loss = " + "{:.6f}".format(loss) +
                              ", training accuracy = " + "{:.5f}".format(acc))

                   step += 1

               test_data = self.mnist.test.images[:self.test_len].reshape((-1, self.timesteps, self.input_size))

               test_label = self.mnist.test.labels[:self.test_len]

               print(" [*] Epoch " + str(i+1) + ": Optimization has finished, Testing Accuracy is " ,\

                   sess.run(self.accuracy, feed_dict={self.x: test_data, self.y: test_label,self.istate: \

                       np.zeros((self.test_len, 2 * self.hidden_size))}))

               summary = sess.run(merged, feed_dict={self.x: test_data, self.y: test_label})

               writer.add_summary(summary, i)

if __name__ == '__main__':

   datapath = "C:\\Users\\Administrator\\Desktop\\deep_lab\\mnist_data"

   logpath = "/tensorboard_log/tf_Ex" #tensorboard

   obj = mnistExp(datapath,logpath)

   obj.run()

Experiment results:

A quick note on TensorBoard, which makes it easier to inspect training visually. During training, an event file is automatically written into the log folder.

There is no need to wait for training to finish; the command below can be run at any time to monitor progress:
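The screenshot of the original command is not preserved here; assuming the logpath configured in the script above, the usual invocation would be:

    tensorboard --logdir=/tensorboard_log/tf_Ex

after which the dashboards are served in the browser, by default at http://localhost:6006.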

Training results (TensorBoard screenshots): the SCALARS tab (the Cost and Accuracy curves), the GRAPHS tab (the computation graph), and the DISTRIBUTIONS and HISTOGRAMS tabs (the weight and bias summaries).

References:

    https://www.cnblogs.com/steven-yang/p/6407445.html

    http://www.cnblogs.com/wacc/p/5341670.html

    https://en.wikipedia.org/wiki/Matrix_calculus

    https://www.cnblogs.com/zhbzz2007/p/6339346.html

    http://zhwhong.ml/2017/02/24/Backpropagation-principle/

    https://zhuanlan.zhihu.com/p/26892413

    http://www.sohu.com/a/195366563_465975

    http://colah.github.io/posts/2015-08-Understanding-LSTMs/

    http://blog.csdn.net/u010754290/article/details/47167979
