1. 前言
本文意在对LSTM有初步的实践操作,
相关代码在tec4tensorflow , 相关数据集在 stock_dataset.csv
2. 简介
LSTM全称长短期记忆人工神经网络(Long-Short Term Memory),是对RNN的变种。可以有效的解决梯度消失和梯度爆炸的问题。
在LSTM中,我们可以控制丢弃什么信息,存放什么信息。具体的理论这里就不多说了,推荐一篇博文Understanding LSTM Networks[2],里面有对LSTM详细的介绍,也可以看网友的翻译版[译] 理解 LSTM 网络。
3. LSTM应用实践--股票预测
在对理论有理解的基础上,我们使用LSTM对股票每日最高价进行预测。在本例中,仅使用一维特征。本例取每日最高价作为输入特征[x],后一天的最高价最为标签[y]。
数据格式如下:
时间,最高价格3.1 LSTM 实践
了解了所训练的数据以后,我们就可以去动手实践了。
3.1.1 导入数据的操作
添加所需要的工具库:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
读取数据集的数据
f=open('stock_dataset.csv')#指定文件引用
df=pd.read_csv(f)#读入股票数据
data=np.array(df['最高价'])#获取最高价列
data=data[::-1]#反转,使数据按照日期先后顺序排列
显示股票数据
#以折线图展示数据
plt.figure()
plt.plot(data)#绘图
plt.show()# 显示
数据预处理
normalize_data=(data-np.mean(data))/np.std(data) #标准化
normalize_data=normalize_data[:,np.newaxis] #增加维度
time_step=20 #时间步
rnn_unit=10 #hidden layer units
batch_size=60 #每一批次训练多少个样例
input_size=1 #输入层维度,与数据集合中的x一致
output_size=1 #输出层维度,与label 一致
lr=0.0006 #学习率Learning Rate
train_x,train_y=[],[] #训练集
for i in range(len(normalize_data)-time_step-1):
x=normalize_data[i:i+time_step]
y=normalize_data[i+1:i+time_step+1]
train_x.append(x.tolist())
train_y.append(y.tolist())
3.2 定义神经网络中使用的变量
X=tf.placeholder(tf.float32, [None,time_step,input_size]) #每批次输入网络的tensor
Y=tf.placeholder(tf.float32, [None,time_step,output_size]) #每批次tensor对应的标签
#输入层、输出层权重、偏置
weights={
'in':tf.Variable(tf.random_normal([input_size,rnn_unit])),
'out':tf.Variable(tf.random_normal([rnn_unit,1]))
}
biases={
'in':tf.Variable(tf.constant(0.1,shape=[rnn_unit,])),
'out':tf.Variable(tf.constant(0.1,shape=[1,]))
}
3.3 定义lstm网络
简单的画了一下网络模型。最近刚刚会用OmmiGraffle,这次读的代码有些匆忙。回头把这张图补上。
网络图形为了实现上图的循环神经网络,有必要先了解一下这些函数:
(1)tf.reshape函数用于对tensor的维度进行修改。示例如下:tf.reshape(<待修改维度的tensor>,[函数的维度])
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit])
#time_step=20; rnn_unit=10
(2)定义单个基本的LSTM单元,应该使用tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.0, state_is_tuple=True)
在LSTM单元中,有2个状态值,分别是c和h,分别对应于下图中的c和h。其中h在作为当前时间段的输出的同时,也是下一时间段的输入的一部分。当state_is_tuple=True的时候,state是元组形式,state=(c,h)。如果是False,那么state是一个由c和h拼接起来的张量,state=tf.concat(1,[c,h])。在运行时,则返回2值,一个是h,还有一个state。
示例:
一个BasicLSTMCellcell= tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)#rnn_unit=10
(3)初始化网络的state : cell.zero_state() 示例:
init_state=cell.zero_state(batch,dtype=tf.float32) #batch=60
(4)进行时间展开的方法:tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,dtype=None,time_major=False)
此函数会通过,inputs中的max_time将网络按时间展开
参数解析
cell:将上面的lstm_cell传入就可以
inputs:[batch_size, max_time, size] , 如果time_major=Flase. [max_time, batch_size, size] ,如果time_major=True。
sequence_length:是一个list,如果你要输入三句话,且三句话的长度分别是5,10,25,那么sequence_length=[5,10,25]
返回:(outputs, states):output,[batch_size, max_time, num_units]。如果time_major=False。 [max_time,batch_size,num_units]如果time_major=True。states:[batch_size, 2*len(cells)]或[batch_size,s]
outputs输出的是最上面一层的输出,states保存的是最后一个时间输出的states
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32)
该部分的完整片段:
def lstm(batch): #参数:输入网络批次数目
w_in=weights['in']
b_in=biases['in']
input=tf.reshape(X,[-1,input_size]) #需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入
input_rnn=tf.matmul(input,w_in)+b_in
input_rnn=tf.reshape(input_rnn,[-1,time_step,rnn_unit]) #将tensor转成3维,作为lstm cell的输入
cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
init_state=cell.zero_state(batch,dtype=tf.float32)
output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn,initial_state=init_state, dtype=tf.float32) #output_rnn是记录lstm每个输出节点的结果,final_states是最后一个cell的结果
output=tf.reshape(output_rnn,[-1,rnn_unit]) #作为输出层的输入
w_out=weights['out']
b_out=biases['out']
pred=tf.matmul(output,w_out)+b_out
return pred,final_states
3.4 训练模型
定义损失函数:
loss=tf.reduce_mean(tf.square(tf.reshape(pred,[-1])-tf.reshape(Y, [-1])))
定义优化器
train_op=tf.train.AdamOptimizer(lr).minimize(loss)
保存训练过程
saver=tf.train.Saver(tf.global_variables())
...
print("保存模型:",saver.save(sess,'stock.model'))
初始化图中的变量
sess.run(tf.global_variables_initializer())
传递数据:
_,loss_=sess.run([train_op,loss],feed_dict={X:train_x[start:end],Y:train_y[start:end]})
程序片段截图3.5 预测模型
程序片段截图4. 安装错误:
matplotlib: RuntimeError: Python is not installed as a framework
RuntimeError: Python is not installed as a framework 错误解决方案
5. 有用的工具命令
编码转变 iconv-f"gbk" -t"utf-8" < infile > outfile
6. 参考文献
[1] tensorflow 官方使用说明recurrent-neural-networks
[2] Understanding LSTM Networks
[3] The Unreasonable Effectiveness of Recurrent Neural Networks
[5] tensorflow学习笔记(六):LSTM 与 GRU
网友评论