BPTT Derivation and an LSTM-Based MNIST Handwritten Digit Recognition Experiment


Author: Rethinkpossible | Published 2018-08-10 10:17

    BPTT (Back Propagation Through Time) Derivation

       Since RNNs are widely used for sequence labeling problems, we take sequence

    labeling as the running example for explaining BPTT. The figure below shows the
    familiar structure of an RNN unrolled in time.

    (Figure 1)

      Figure 1 can then be made concrete as the figure below:

    (Figure 2)
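
    Since the derivation itself lives in the two figures, here is a compact LaTeX
    sketch of the standard BPTT gradient for a vanilla RNN on a sequence labeling
    task. The notation below is a common convention and an assumption on my part;
    it may not match the symbols used in the figures.

        % Forward pass: hidden state and per-step softmax output
        s_t = \tanh(U x_t + W s_{t-1} + b), \qquad
        \hat{y}_t = \mathrm{softmax}(V s_t + c)

        % Total loss is the sum of per-step cross-entropies
        E = \sum_t E_t, \qquad E_t = -\, y_t^{\top} \log \hat{y}_t

        % BPTT: the gradient w.r.t. the recurrent weight W sums over all earlier
        % steps k, chaining the state Jacobians back through time
        \frac{\partial E_t}{\partial W}
          = \sum_{k=0}^{t}
            \frac{\partial E_t}{\partial \hat{y}_t}\,
            \frac{\partial \hat{y}_t}{\partial s_t}
            \left( \prod_{j=k+1}^{t} \frac{\partial s_j}{\partial s_{j-1}} \right)
            \frac{\partial s_k}{\partial W}

    The repeated product of Jacobians \prod_j \partial s_j / \partial s_{j-1} is
    the term that vanishes or explodes over long spans, which is the usual
    motivation for the LSTM used in the experiment below.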

    Notes on the LSTM-Based MNIST Handwritten Digit Recognition Experiment

    Unrolling the basic LSTM block into a temporal structure gives the following:

      Note that the figure above does not show different neurons in different blocks;

    it shows the state of the same neuron in the same block at different time steps,
    and how information is passed between those time steps. The internals of a block,
    together with the formulas, are shown in the figure below:
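
    Since that figure is not reproduced here, the standard LSTM block equations, as
    given in the Understanding-LSTMs post cited in the references (and matching the
    BasicLSTMCell with a forget gate used in the code below), are:

        f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)           % forget gate
        i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)           % input gate
        \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)    % candidate cell state
        C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t        % cell state update
        o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)           % output gate
        h_t = o_t \odot \tanh(C_t)                             % hidden state / output

    Here \sigma is the logistic sigmoid, \odot is element-wise multiplication, and
    [h_{t-1}, x_t] is the concatenation of the previous hidden state and the current
    input.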

    The code is as follows (each 28x28 MNIST image is treated as a sequence of 28 rows, i.e. 28 time steps with 28 features each):

    from __future__ import print_function

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib import rnn
    from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

    class mnistExp:

        def __init__(self, datapath, logpath):
            self.mnist = read_data_sets(datapath, one_hot=True)
            self.logpath = logpath
            self.learning_rate = 0.001
            self.iters_num = 100000     # upper bound on samples seen per epoch
            self.batch_size = 128
            self.print_step = 10        # log every print_step batches
            self.input_size = 28        # each image row is one 28-dim input vector
            self.timesteps = 28         # 28 rows = 28 time steps
            self.hidden_size = 128
            self.class_num = 10
            self.epoch_num = 30
            self.test_len = 300

        def InitGblVar(self):
            # Build the graph: placeholders, weights/biases, LSTM, loss, optimizer, metrics.
            with tf.name_scope('Inputs'):
                self.x = tf.placeholder("float", [None, self.timesteps, self.input_size], name='X')
                self.istate = tf.placeholder("float", [None, 2 * self.hidden_size], name='istate')
                self.y = tf.placeholder("float", [None, self.class_num], name='Y')
            with tf.name_scope('Weights'):
                self.weights = {
                    'hidden': tf.Variable(tf.random_normal([self.input_size, self.hidden_size])),
                    'out': tf.Variable(tf.random_normal([self.hidden_size, self.class_num]))
                }
                tf.summary.histogram('weights_hidden', self.weights['hidden'])
                tf.summary.histogram('weights_out', self.weights['out'])
            with tf.name_scope('Biases'):
                self.biases = {
                    'hidden': tf.Variable(tf.random_normal([self.hidden_size])),
                    'out': tf.Variable(tf.random_normal([self.class_num]))
                }
                tf.summary.histogram('bias_hidden', self.biases['hidden'])
                tf.summary.histogram('bias_out', self.biases['out'])
            self.pred = self.LSTM(self.x, self.istate, self.weights, self.biases)
            with tf.name_scope('Cost'):
                self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.pred, labels=self.y))
                tf.summary.scalar('Cost', self.cost)
            with tf.name_scope('Train'):
                self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)
            with tf.name_scope('Prediction'):
                self.correct_pred = tf.equal(tf.argmax(self.pred, 1), tf.argmax(self.y, 1))
                self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
                tf.summary.scalar('Accuracy', self.accuracy)
            self.init = tf.global_variables_initializer()

        def LSTM(self, _batch_x, _istate, _weights, _biases):
            # Note: _istate is accepted for symmetry with the feed_dicts below, but
            # static_rnn builds its own zero initial state and never reads it.
            with tf.name_scope('lstm_block'):
                # (batch, time, input) -> (time, batch, input)
                _batch_x = tf.transpose(_batch_x, [1, 0, 2])
                # Flatten to (time * batch, input) for the hidden projection
                _batch_x = tf.reshape(_batch_x, [-1, self.input_size])
                _batch_x = tf.matmul(_batch_x, _weights['hidden']) + _biases['hidden']
                # Split into a list of `timesteps` tensors of shape (batch, hidden)
                _batch_x = tf.split(_batch_x, self.timesteps, 0)
                lstm_block = rnn.BasicLSTMCell(self.hidden_size, forget_bias=1.0)
                outputs, states = rnn.static_rnn(lstm_block, _batch_x, dtype=tf.float32)
            # Classify from the hidden state at the last time step
            return tf.matmul(outputs[-1], _weights['out']) + _biases['out']

        def run(self):
            self.InitGblVar()
            with tf.Session() as sess:
                sess.run(self.init)
                merged = tf.summary.merge_all()
                writer = tf.summary.FileWriter(self.logpath, sess.graph)
                for i in range(self.epoch_num):
                    step = 1
                    while step * self.batch_size < self.iters_num:
                        batch_xs, batch_ys = self.mnist.train.next_batch(self.batch_size)
                        # tf.summary.image('batch_xs', batch_xs, max_outputs=10)
                        batch_xs = batch_xs.reshape((self.batch_size, self.timesteps, self.input_size))
                        sess.run(self.optimizer, feed_dict={self.x: batch_xs, self.y: batch_ys,
                                                            self.istate: np.zeros((self.batch_size, 2 * self.hidden_size))})
                        if step % self.print_step == 0:
                            acc, loss, summary = sess.run([self.accuracy, self.cost, merged],
                                                          feed_dict={self.x: batch_xs, self.y: batch_ys,
                                                                     self.istate: np.zeros((self.batch_size, 2 * self.hidden_size))})
                            print(" [*] Iter " + str(step * self.batch_size) + ", minibatch loss = " +
                                  "{:.6f}".format(loss) + ", training accuracy = " + "{:.5f}".format(acc))
                        step += 1
                    test_data = self.mnist.test.images[:self.test_len].reshape((-1, self.timesteps, self.input_size))
                    test_label = self.mnist.test.labels[:self.test_len]
                    print(" [*] Epoch " + str(i + 1) + ": optimization finished, testing accuracy is",
                          sess.run(self.accuracy, feed_dict={self.x: test_data, self.y: test_label,
                                                             self.istate: np.zeros((self.test_len, 2 * self.hidden_size))}))
                    # Write epoch-level summaries for TensorBoard
                    summary = sess.run(merged, feed_dict={self.x: test_data, self.y: test_label})
                    writer.add_summary(summary, i)

    if __name__ == '__main__':
        datapath = "C:\\Users\\Administrator\\Desktop\\deep_lab\\mnist_data"
        logpath = "/tensorboard_log/tf_Ex"  # TensorBoard log directory
        obj = mnistExp(datapath, logpath)
        obj.run()
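
    For readers on TensorFlow 2.x, where tf.contrib and the placeholder/Session
    APIs above no longer exist, a minimal tf.keras sketch of the same architecture
    (a linear projection of each image row, a 128-unit LSTM, and a 10-way output
    layer) could look like the following. This port is my own assumption, not part
    of the original experiment; the hyperparameters are copied from the class above.

    import tensorflow as tf

    # MNIST as (N, 28, 28): each image is 28 time steps of 28 features
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0

    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(28, 28)),
        tf.keras.layers.Dense(128),   # per-row linear projection, like weights['hidden']
        tf.keras.layers.LSTM(128),    # counterpart of BasicLSTMCell(128)
        tf.keras.layers.Dense(10),    # logits, like weights['out']
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=["accuracy"])
    model.fit(x_train, y_train, batch_size=128, epochs=1)
    model.evaluate(x_test, y_test)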

    Experimental results:

      A quick word on TensorBoard, which makes training easy to inspect visually.

    During training, an event file (named like events.out.tfevents.*) is
    automatically created in the log folder.

       You do not have to wait for training to finish; you can launch TensorBoard
    at any time to watch the run:

        tensorboard --logdir=/tensorboard_log/tf_Ex

     Training results:

        (Screenshots of the four TensorBoard tabs are omitted: SCALARS shows the
    Cost and Accuracy curves, GRAPHS shows the computation graph with the
    Inputs/Weights/Biases/Cost/Train/Prediction scopes, and DISTRIBUTIONS and
    HISTOGRAMS show the weight and bias summaries recorded above.)

    References:

        https://www.cnblogs.com/steven-yang/p/6407445.html

        http://www.cnblogs.com/wacc/p/5341670.html

        https://en.wikipedia.org/wiki/Matrix_calculus

        https://www.cnblogs.com/zhbzz2007/p/6339346.html

        http://zhwhong.ml/2017/02/24/Backpropagation-principle/

        https://zhuanlan.zhihu.com/p/26892413

        http://www.sohu.com/a/195366563_465975

        http://colah.github.io/posts/2015-08-Understanding-LSTMs/

        http://blog.csdn.net/u010754290/article/details/47167979
