美文网首页
RNN实现股价预测

RNN实现股价预测

作者: y_7539 | 发表于 2023-01-31 17:06 | 被阅读 0 次
    import numpy as np
    import pandas as pd

    # Load the training split of the price data.
    df = pd.read_csv(filepath_or_buffer="datas/zgpa_train.csv")
    # Preview the first five rows (shown as the notebook cell's output).
    df.head(5)
    
    # [image.png omitted: output of df.head() above]
    price = df["close"]
    # 归一化
    price_norm = price/max(price)
    
    import matplotlib.pyplot as plt
    %matplotlib inline
    plt.figure(figsize=(5,3))
    plt.plot(price)
    plt.xlabel("time")
    plt.ylabel("price")
    plt.show()
    
    # [image.png omitted: the closing-price plot produced above]
    # Build sliding-window samples (X) and next-step targets (y).
    def extract_data(data, time_step):
        """Split a 1-D series into sliding windows and next-value targets.

        Each sample is ``time_step`` consecutive values. Its target is the
        value immediately after the window, so with ``time_step=8`` the first
        8 values predict the 9th.

        Args:
            data: 1-D sequence of numbers (list, NumPy array or pandas
                Series). Elements are taken by position, so a Series with a
                non-default index is handled correctly.
            time_step: Window length.

        Returns:
            ``(X, y)``. ``X`` is an array of shape ``(n_windows, time_step, 1)``,
            the ``(samples, timesteps, features)`` layout that the SimpleRNN
            ``input_shape=(time_step, 1)`` consumes. ``y`` is a list of the
            ``n_windows`` target values. ``n_windows`` is
            ``len(data) - time_step``, or 0 when the series is too short.
        """
        # Access by position. On a pandas Series, data[k] is a label lookup,
        # which fails or picks the wrong element for non-default indexes.
        values = np.asarray(data)
        n_windows = max(len(values) - time_step, 0)
        X = np.array([values[i:i + time_step] for i in range(n_windows)])
        # Use explicit dimensions so an empty result still has a 3-D shape
        # (X is 1-D when there are no windows, so X.shape[1] is unavailable).
        X = X.reshape(n_windows, time_step, 1)
        y = values[time_step:time_step + n_windows].tolist()
        return X, y
    
    # Window length: each sample is 8 consecutive normalized prices.
    time_step = 8

    # Inputs are the previous `time_step` normalized prices and the target is
    # the next one (the first 8 values predict the 9th).
    X, y = extract_data(price_norm, time_step=time_step)
    
    from keras.models import Sequential
    from keras.layers import Dense, SimpleRNN

    # Model: one SimpleRNN layer (5 units, ReLU) reading windows of shape
    # (time_step, 1), then a single linear output unit.
    model = Sequential([
        SimpleRNN(units=5, input_shape=(time_step, 1), activation="relu"),
        Dense(units=1, activation="linear"),
    ])
    # Adam optimizer with a mean-squared-error regression loss.
    model.compile(optimizer="adam", loss="mean_squared_error")

    # Train for 200 epochs in mini-batches of 30. Author's tip: if the loss
    # stays unchanged, re-create (reload) the model and train again.
    model.fit(X, np.array(y), batch_size=30, epochs=200)
    
    # Predict on the training windows and map back to price units by undoing
    # the normalization (multiplying by the training maximum). Compute the
    # maximum once: calling max(price) inside the comprehension re-scanned
    # the whole series for every element.
    price_max = max(price)
    y_train_predict = model.predict(X) * price_max
    y_train = [v * price_max for v in y]

    # Overlay predicted and true training prices.
    plt.figure(figsize=(5,3))
    plt.plot(y_train_predict, label="predict price")
    plt.plot(y_train, label="true price")
    plt.xlabel("time")
    plt.ylabel("price")
    plt.legend()
    plt.show()
    
    # [image.png omitted: predicted vs. true training-price plot above]
    # Load the held-out test split to evaluate the trained model on.
    test_data = pd.read_csv(filepath_or_buffer="datas/zgpa_test.csv")
    # Preview the first five rows (shown as the notebook cell's output).
    test_data.head(5)
    
    # [image.png omitted: output of test_data.head() above]
    price_test = test_data["close"]
    # Normalize with the *training* maximum so train and test share one scale.
    # Compute it once: calling max(price) inside the comprehension below
    # re-scanned the whole training series for every element.
    price_max = max(price)
    price_test_norm = price_test / price_max
    x_test_norm, y_test_norm = extract_data(price_test_norm, time_step)
    # Predict the test windows and map predictions and targets back to prices.
    y_test_predict = model.predict(x_test_norm) * price_max
    y_test = [v * price_max for v in y_test_norm]

    # Overlay predicted and true test prices.
    plt.figure(figsize=(5,3))
    plt.plot(y_test_predict, label="test predict price")
    plt.plot(y_test, label="test true price")
    plt.xlabel("time")
    plt.ylabel("price")
    plt.legend()
    plt.show()
    
    # [image.png omitted: predicted vs. true test-price plot above]
    # Save the true and predicted test prices side by side.
    result_y_test = np.array(y_test).reshape(-1, 1)   # as an (n, 1) column
    result_y_test_predict = y_test_predict            # model output, one column (Dense units=1)
    print(result_y_test.shape, result_y_test_predict.shape)

    # Join the two columns horizontally and label them.
    result = pd.DataFrame(
        np.hstack((result_y_test, result_y_test_predict)),
        columns=["real_price_test", "predict_price_test"],
    )
    result.to_csv("zgpa_predict_test.csv")
    # Author's observation: the predictions lag the real prices by about one step.
    

    相关文章

      网友评论

          本文标题:RNN实现股价预测

          本文链接:https://www.haomeiwen.com/subject/kpkbhdtx.html