LSTM

作者: scpy | 来源:发表于2018-12-24 13:46 被阅读0次

    LSTM(Long Short-Term Memory)是长短期记忆网络,是一种时间递归神经网络,适合于处理和预测时间序列中间隔和延迟相对较长的重要事件。

    # -*- coding: utf-8 -*-
    import numpy as np
    import tensorflow as tf
    import matplotlib as mpl
    from matplotlib import pyplot as plt
    import os
    
    
    class LSTM():
        """Stacked-LSTM regressor for one-step-ahead time-series prediction.

        Built on the TF1 graph API: placeholders -> MultiRNNCell of
        BasicLSTMCells -> dynamic_rnn -> dense(1), trained with MSE loss
        via Adagrad (tf.contrib.layers.optimize_loss).
        """

        def __init__(self,
                     hidden_size=30,  # hidden units per LSTM cell
                     num_layers=2,  # number of stacked LSTM layers
                     timestamps=10,  # length of the input window
                     batch_size=32,
                     training_examples=10000,  # number of training samples
                     test_examples=1000,
                     model_path="D:/kaggle/rnn/",
                     load_model=False,
                     training_steps=10000,
                     logdir="D:/kaggle/rnn/"
                     ):
            self.batch_size = batch_size
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.timestamps = timestamps
            self.train_examples = training_examples
            self.test_examples = test_examples
            self.log_dir = logdir
            self.training_steps = training_steps
            self.model_path = model_path
            self.build_graph()
            self.init_op()
            if load_model:
                self.load_model(model_path)

        def build_graph(self):
            """Build the computation graph: placeholders, stacked LSTM, loss, train op."""
            self.graph = tf.Graph()
            with self.graph.as_default():
                # NOTE(review): the window is fed on the *feature* axis
                # (shape [batch, 1, timestamps]), so the RNN unrolls over a
                # single time step of `timestamps` features — confirm this is
                # the intended framing.
                self.X = tf.placeholder(tf.float32, shape=[None, 1, self.timestamps])
                self.y = tf.placeholder(tf.float32, shape=[None, 1])
                # Multi-layer LSTM.
                cell = tf.nn.rnn_cell.MultiRNNCell([
                    tf.nn.rnn_cell.BasicLSTMCell(self.hidden_size)
                    for _ in range(self.num_layers)
                ])
                # Unroll the stacked cells and run the forward pass.
                outputs, _ = tf.nn.dynamic_rnn(cell, self.X, dtype=tf.float32)
                # outputs shape: [batch_size, time, hidden_size]
                output = outputs[:, -1, :]  # output of the last time step
                self.prediction = tf.contrib.layers.fully_connected(output, 1, activation_fn=None)
                self.loss = tf.losses.mean_squared_error(labels=self.y, predictions=self.prediction)
                tf.summary.scalar('loss', self.loss)
                # BUG FIX: no global-step variable was ever created, so
                # tf.train.get_global_step() returned None and checkpoints
                # were saved without a step suffix. Create it explicitly.
                self.global_step = tf.train.get_or_create_global_step()
                self.train_op = tf.contrib.layers.optimize_loss(
                    self.loss, self.global_step,
                    optimizer="Adagrad", learning_rate=0.1
                )
                self.init = tf.global_variables_initializer()
                self.merged_summary_op = tf.summary.merge_all()
                self.saver = tf.train.Saver()

        def init_op(self):
            """Create the session, initialize variables, and open the summary writer."""
            self.sess = tf.Session(graph=self.graph)
            self.sess.run(self.init)
            self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

        def train(self, train_X, train_y):
            """Train for `training_steps` mini-batches drawn from (train_X, train_y).

            train_X is expected to have shape (n, 1, timestamps) and train_y
            shape (n, 1) — e.g. (10000, 1, 10) with the defaults.
            """
            # Build the shuffled, repeating batch pipeline inside this model's graph.
            with self.graph.as_default():
                ds = tf.data.Dataset.from_tensor_slices((train_X, train_y))
                ds = ds.repeat().shuffle(1000).batch(self.batch_size)
                X, y = ds.make_one_shot_iterator().get_next()

            for i in range(self.training_steps):
                batch_x, batch_y = self.sess.run([X, y])
                feed_dict = {self.X: batch_x, self.y: batch_y}
                _, loss, summary_str = self.sess.run(
                    [self.train_op, self.loss, self.merged_summary_op], feed_dict=feed_dict)
                if i % 1000 == 0:
                    print('train step:' + str(i) + ', loss=' + str(loss))
                    # BUG FIX: the original called tf.train.get_global_step()
                    # here, which looks at the *default* graph (not self.graph)
                    # and returned None; use the step created in build_graph.
                    self.saver.save(self.sess,
                                    os.path.join(self.model_path, "lstm.ckpt"),
                                    global_step=self.global_step)

        def test(self, test_X, test_y):
            """Evaluate on (test_X, test_y), print the RMSE, and plot the predictions."""
            with self.graph.as_default():
                ds = tf.data.Dataset.from_tensor_slices((test_X, test_y))
                # BUG FIX: Dataset.batch returns a new dataset; the original
                # discarded the result, so the dataset was never batched.
                ds = ds.batch(1)
                X, y = ds.make_one_shot_iterator().get_next()  # X, y are tensors

            predictions = []
            labels = []
            for i in range(self.test_examples):
                batch_x, batch_y = self.sess.run([X, y])
                # Use the configured window length instead of a hard-coded 10.
                batch_x = np.reshape(batch_x, (-1, 1, self.timestamps))
                p = self.sess.run([self.prediction], feed_dict={self.X: batch_x})
                predictions.append(p)
                labels.append(batch_y)
            predictions = np.array(predictions).squeeze()  # drop size-1 axes
            labels = np.array(labels).squeeze()
            # BUG FIX: the original took sqrt of each squared error and then
            # averaged, which is the mean absolute error, not the RMSE.
            rmse = np.sqrt(((predictions - labels) ** 2).mean())
            print("root mean square error is: %f" % rmse)

            # Plot predictions (ground-truth plot left disabled as in the original).
            plt.figure()
            plt.plot(predictions, label="predictions")
            # plt.plot(labels, label="real_sin")
            plt.legend()
            plt.show()

        def load_model(self, model_path):
            """Restore the latest checkpoint under model_path into the session.

            Raises RuntimeError if model_path does not exist.
            """
            if not os.path.exists(model_path):
                # BUG FIX: the exception was constructed but never raised.
                raise RuntimeError("model_path does not exist")
            ckpt = tf.train.get_checkpoint_state(self.model_path)
            if ckpt and ckpt.model_checkpoint_path:
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
    
    
    def get_data(seq, timestamp):
        """Slice a 1-D sequence into supervised-learning pairs.

        Each feature row is a window of `timestamp` consecutive values wrapped
        in an extra axis (shape (1, timestamp)); the matching label is the
        single value immediately following that window.

        Returns (X, y) as float32 arrays of shapes
        (len(seq) - timestamp, 1, timestamp) and (len(seq) - timestamp, 1).
        """
        n_samples = len(seq) - timestamp
        windows = [[seq[start: start + timestamp]] for start in range(n_samples)]
        targets = [[seq[start + timestamp]] for start in range(n_samples)]
        return np.array(windows, dtype=np.float32), np.array(targets, dtype=np.float32)
    
    
    if __name__ == '__main__':
        sample_gap = 0.01  # sampling interval along the sine wave
        lstm = LSTM(load_model=False)

        # Train and test sets are consecutive, non-overlapping segments of the
        # same sine curve; each segment carries `timestamps` extra points so
        # get_data can form complete windows.
        test_start = (lstm.train_examples + lstm.timestamps) * sample_gap
        test_end = test_start + (lstm.test_examples + lstm.timestamps) * sample_gap

        train_seq = np.sin(np.linspace(0, test_start,
                                       lstm.train_examples + lstm.timestamps,
                                       dtype=np.float32))
        test_seq = np.sin(np.linspace(test_start, test_end,
                                      lstm.test_examples + lstm.timestamps,
                                      dtype=np.float32))

        train_x, train_y = get_data(train_seq, lstm.timestamps)
        test_X, test_y = get_data(test_seq, lstm.timestamps)

        lstm.train(train_x, train_y)
        lstm.test(test_X, test_y)
    

    相关文章

      网友评论

          本文标题:LSTM

          本文链接:https://www.haomeiwen.com/subject/bgxzkqtx.html