一、基本原理
循环神经网络(recurrent neural network, RNN)是一类具有短期记忆能力的神经网络,其中神经元不但可以接受其它神经元的信息,也可以接受自身的信息,形成环路的网络结构。循环神经网络已被广泛用于语音识别、语言模型以及自然语言生成等任务上。
data:image/s3,"s3://crabby-images/09480/0948032efb4fed59b723501f988ce3a84e67077f" alt=""
给定一个输入序列,循环神经网络更新隐藏层活性值
,其中
,
为一个非线性函数,也可以是一个前馈网络。理论上,循环神经网络可以近似任意非线性动力系统(通用近似定理)。循环神经网络应用到机器学习中可分为以下几种模式:序列到类别模式、同步的序列到序列模式、异步的序列到序列模式。
二、门控网络
2.1、LSTM
长短期记忆(long short term memory, LSTM)网络是循环神经网络的一个变体,可以有效避免梯度消失或爆炸问题。LSTM引入门机制(gating mechanism)来控制信息传递的路径,分别为输入门,遗忘门
,输出门
。遗忘门控制上一个时刻的内部状态
需要遗忘多少信息,输入门控制当前时刻的候选状态
有多少信息需要保存,输出门控制当前时刻的内部状态
有多少信息需要输出给外部状态
。三个门的计算方式如下:
LSTM循环单元结构如下所示,其计算过程为:(1)首先利用上一时刻的外部状态和当前时刻的输入
,计算出三个门,以及候选状态
;(2)结合遗忘门
和输入门
来更新记忆单元
;(3)结合输出门
,将内部状态的信息传递给外部状态
。
data:image/s3,"s3://crabby-images/5da2c/5da2c7a747eba0fc8a4acea5227cc585fdb16e7f" alt=""
通过LSTM循环单元,整个网络可以建立较长距离的时序依赖关系。描述如下:
循环神经网络中的隐状态存储了历史信息,可以看作是一种记忆(memory) 。在简单循环网络中,隐状态每个时刻都会被重写,因此可以看作是一种短期记忆(short-term memory)。在神经网络中,长期记忆(long-term memory) 可以看作是网络参数,隐含了从训练数据中学到的经验,并更新周期要远远慢于短期记忆。而在LSTM网络中,记忆单元
可以在某个时刻捕捉到某个关键信息,并有能力将此关键信息保存一定的时间间隔。记忆单元
中保存信息的生命周期要长于短期记忆
,但又远远短于长期记忆,因此称为长的短期记忆。
2.2、GRU
门控循环单元(gated recurrent unit, GRU)网络是一种比LSTM更简单的循环神经网络。GRU网络也是引入门机制来控制信息更新的方式。在LSTM网络中,输入门和遗忘门是互补关系,用两个门比较冗余。GRU将输入门与和遗忘门合并成一个门:更新门。同时,GRU也不引入额外的记忆单元, 直接在当前状态和历史状态
之间引入线性依赖关系。
更新门 ,用来控制当前状态需要从历史状态中 保留多少信息(不经过非线性变换),以及需要从候选状态中接受多少新信息。当
时,当前状态
和历史状态
之间为非线性函数。若同时有
时,GRU网络退化为简单循环网络;若同时有
时,当 前状态
只和当前输入
相关,和历史状态
无关。当
时,当前状 态
等于上一时刻状态
,和当前输入
无关。
重置门 ,当
时,候选状态
只和当前输入
相关,和历史状态无关。当
时,候选状态
和当前输入
和历史状态
相关,和简单循环网络一致。
GRU循环单元结构如下:
data:image/s3,"s3://crabby-images/7dd73/7dd73745e694faec267b3a65d5ae58ee2ccba873" alt=""
三、算法实现
本案例将通过循环神经网络对世界人均GDP的预测来感受其作用过程。
1、数据获取。读取世界银行网站人均GDP数据,选取11个国家,48年的数据,数据格式如下图所示。
from pandas_datareader import wb
countries = ['BR','CA','CN','FR','DE','IN','IL','JP','SA','GB','US']
data = wb.download(indicator='NY.GDP.PCAP.KD',country=countries,start=1970,end=2018)
df = data.unstack().T
df.index = df.index.droplevel(0)
data:image/s3,"s3://crabby-images/bca8b/bca8b3645d9c4bac9a34743a054a0dcab65ee4e6" alt=""
2、搭建神经网络。神经网络包含两个部分:LSTM和全连接层。LSTM为简单单层结构,隐藏张量再通过全连接层线性变换,得到后一年的预测值。
import torch
import torch.nn
class Net(torch.nn.Module):
def __init__(self,input_size,hidden_size):
super(Net,self).__init__()
self.rnn = torch.nn.LSTM(input_size,hidden_size)
self.fc = torch.nn.Linear(hidden_size,1)
def forward(self,x):
x = x[:,:,None]
x,_ = self.rnn(x)
x = self.fc(x)
x = x[:,:,0]
return x
3、训练过程。把1971-2000年数据当做训练数据,把2001-2018年数据当做测试数据,训练前需要对数据进行归一化。如下图所示,在训练过程中,误差逐步减小。
lstm = Net(input_size=1,hidden_size=5)
df_scaled = df/df.loc['2000']
years = df.index
train_seq_len = sum((years>='1971') & (years<='2000'))
test_seq_len = sum(years>'2000')
print('length of train:{},length of test:{}'.format(train_seq_len,test_seq_len))
inputs = torch.tensor(df_scaled.iloc[:-1].values,dtype=torch.float32)
labels = torch.tensor(df_scaled.iloc[1:].values,dtype=torch.float32)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(lstm.parameters())
train_loss_list,test_loss_list = [],[]
for step in range(10001):
preds = lstm(inputs)
train_preds = preds[:train_seq_len]
train_labels = labels[:train_seq_len]
train_loss = criterion(train_preds,train_labels)
optimizer.zero_grad()
train_loss.backward()
optimizer.step()
test_preds = preds[-test_seq_len]
test_labels = labels[-test_seq_len]
test_loss = criterion(test_preds,test_labels)
train_loss_list.append(train_loss)
test_loss_list.append(test_loss)
if step % 500 == 0:
print('epoch:{},train loss:{},test loss:{}'.format(step,train_loss,test_loss))
import matplotlib.pyplot as plt
plt.plot(train_loss_list)
plt.plot(test_loss_list)
plt.show()
data:image/s3,"s3://crabby-images/6667b/6667b129e86d301cf931b378219152a4b4c73dcb" alt=""
4、预测过程。利用训练好的模型预测新的值,下图展示了2000年后的预测值的数值。
from IPython.display import display
preds = lstm(inputs)
df_pred_scaled = pd.DataFrame(preds.detach().numpy(),index=years[1:],columns=df.columns)
df_pred = df_pred_scaled * df.loc['2000']
display(df_pred.loc['2001':])
data:image/s3,"s3://crabby-images/f75a6/f75a67768438e12f1ce8d0600a22c5ebc506c12c" alt=""
为更加形象比较真实值和预测值的差异,以下显示了美国和中国的数据情况,可以看出美国的预测值比较精准,而中国的数据中2017,2018预测值偏差较大,这一方面有可能是由于数据本身存在统计错误,也有可能是中国今年积极推进扶贫工作的结果。
data:image/s3,"s3://crabby-images/5f8db/5f8db5f17f17dd3b0b4d5667687aff2c30f34a38" alt=""
data:image/s3,"s3://crabby-images/2b912/2b91277e666a3d6a3858a8e726b1d95073f391ad" alt=""
参考资料
[1] Vishnu Subramanian. Deep Learning with PyTorch. Packet Publishing. 2018.
[2] 邱锡鹏 著,神经网络与深度学习. https://nndl.github.io/ 2019.
[3] 肖智清 著,神经网络与PyTorch实战. 北京:机械工业出版社. 2018.
[4] 唐进民 编著,深度学习之PyTorch实战计算机视觉. 北京:电子工业出版社. 2018.
[5] Ian Goodfellow 等 著, 赵申剑等 译, 深度学习. 北京:人民邮电出版社, 2017.
吾心自有光明月,千古团圆永无缺。——王守仁《中秋》
网友评论