美文网首页Machine_Learning
LSTM 实践--客运量数据分析(2)

LSTM 实践--客运量数据分析(2)

作者: 斐波那契的数字 | 来源:发表于2018-06-12 14:00 被阅读47次

    1. 预处理

    # Location of the CSV file holding the railway passenger volume series.
    url = './铁路客运量.csv'

    # Let pandas open the source itself: read_csv accepts local paths as well
    # as http(s) URLs, whereas requests.get() raises MissingSchema for a
    # schemeless local path such as the one above.
    df = pd.read_csv(url, encoding='utf-8')

    data = np.array(df['铁路客运量_当期值(万人)'])

    # Standardize the series to zero mean and unit variance.
    normalized_data = (data - np.mean(data)) / np.std(data)

    # Number of consecutive points in every input window.
    seq_size = 3

    # Build (input, target) pairs. Each target is its input window shifted one
    # step ahead, so the last usable start index is len - seq_size - 1; the
    # range below includes it (the target slice then ends exactly at the last
    # observation).
    train_x, train_y = [], []
    for i in range(len(normalized_data) - seq_size):
        # Inputs get a trailing feature axis: seq_size rows of one value each.
        train_x.append(np.expand_dims(normalized_data[i: i + seq_size], axis=1).tolist())
        train_y.append(normalized_data[i + 1: i + seq_size + 1].tolist())

    # One feature per time step.
    input_dim = 1

    # Graph inputs: a batch of windows and the matching shifted targets.
    X = tf.placeholder(tf.float32, [None, seq_size, input_dim])
    Y = tf.placeholder(tf.float32, [None, seq_size])

    2. 模型构建

    def ass_rnn(hidden_layer_size=6):
        """Build the LSTM regression graph on top of the placeholder ``X``.

        A single ``BasicLSTMCell`` is unrolled over ``X`` with
        ``tf.nn.dynamic_rnn``; every per-step output is then projected to one
        value through a shared linear layer (``W``, ``b``).

        Args:
            hidden_layer_size: Number of units in the LSTM cell, and the input
                width of the linear output layer.

        Returns:
            The projected outputs with all size-1 dimensions squeezed away.
            Note that ``tf.squeeze`` without an axis also drops the batch
            dimension when a single sequence is fed.
        """
        # Shared output projection: hidden state -> one regression value.
        W = tf.Variable(tf.random_normal([hidden_layer_size, 1]), name='W')
        b = tf.Variable(tf.random_normal([1]), name='b')

        cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_layer_size)
        outputs, _ = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

        # Replicate W once per batch element so the projection can be applied
        # as a batched matrix product over the rank-3 `outputs` tensor.
        W_repeated = tf.tile(tf.expand_dims(W, 0), [tf.shape(X)[0], 1, 1])

        # tf.batch_matmul was removed in TensorFlow 1.0; tf.matmul performs the
        # same batched product on rank-3 operands.
        out = tf.matmul(outputs, W_repeated) + b
        out = tf.squeeze(out)
        return out

    3. 训练模型

    def train_rnn():
        """Train the LSTM regressor and save its variables as a checkpoint.

        Builds the model with ``ass_rnn()``, minimizes the mean squared error
        against the placeholder ``Y`` with Adam (learning rate 0.003) for
        10000 full-batch steps over ``train_x`` / ``train_y``, prints the loss
        every 10th step and finally saves a checkpoint named ``'ass.model'``.
        """
        model_out = ass_rnn()

        # Mean squared error between the model output and the targets.
        squared_error = tf.square(model_out - Y)
        loss = tf.reduce_mean(squared_error)

        optimizer = tf.train.AdamOptimizer(learning_rate=0.003)
        train_op = optimizer.minimize(loss)

        # Saver over every variable known at this point
        # (tf.global_variables() is the newer spelling from TF 0.12 on).
        saver = tf.train.Saver(tf.all_variables())

        # The whole training set is fed on every step.
        full_batch = {X: train_x, Y: train_y}

        with tf.Session() as sess:
            tf.get_variable_scope().reuse_variables()  # old API

            # tf.global_variables_initializer() is the newer spelling (TF 0.12+).
            sess.run(tf.initialize_all_variables())

            for step in range(10000):
                _, current_loss = sess.run([train_op, loss], feed_dict=full_batch)

                if step % 10 != 0:
                    continue

                # Loss of the training batch just used for this step.
                print(step, current_loss)

            checkpoint_path = saver.save(sess, 'ass.model')
            print("保存模型: ", checkpoint_path)

    4. 预测

    def prediction():
        """Restore the trained model, forecast 12 further steps and plot them.

        The graph is rebuilt with ``ass_rnn()`` and its variables are restored
        from ``'./ass.model'``. Starting from the last ``seq_size`` observed
        (normalized) values, the model is applied repeatedly: the last element
        of each output is taken as the next value and appended to the sliding
        window. The normalized history is drawn in blue and the forecast in
        red, continuing on the x axis right after the history.
        """
        out = ass_rnn()

        # tf.global_variables() is the newer spelling of tf.all_variables().
        saver = tf.train.Saver(tf.all_variables())

        with tf.Session() as sess:
            tf.get_variable_scope().reuse_variables()  # old API

            saver.restore(sess, './ass.model')

            # Seed with the most recent seq_size observations so that the
            # first forecast is the value following the end of the series,
            # which is where the forecast curve is drawn below. (The last
            # training window ends earlier and would shift the forecast.)
            prev_seq = np.expand_dims(normalized_data[-seq_size:], axis=1)

            predict = []
            for _ in range(12):
                next_seq = sess.run(out, feed_dict={X: [prev_seq]})
                # The last element of the output is the one-step-ahead value.
                predict.append(next_seq[-1])
                # Slide the window: drop the oldest value, append the forecast.
                prev_seq = np.vstack((prev_seq[1:], next_seq[-1]))

            history_len = len(normalized_data)

            plt.figure()
            plt.plot(list(range(history_len)), normalized_data, color='b')
            plt.plot(list(range(history_len, history_len + len(predict))), predict, color='r')
            plt.show()

    # Training is disabled here. Uncomment the call below to run train_rnn(),
    # which saves the 'ass.model' checkpoint that prediction() restores from
    # './ass.model'.
    # train_rnn()

    prediction()

    相关文章

      网友评论

        本文标题:LSTM 实践--客运量数据分析(2)

        本文链接:https://www.haomeiwen.com/subject/zqileftx.html