(一)门控制循环单元GRU
这是一个在LSTM 之后发出的神经网络,他和LSTM模型性能差不多,但是结构要更加简单。
我们在实现RNN的时候我们时使用state这个变量来存储隐藏信息,然而这个隐藏信息太小不能很好的掌握序列的隐藏信息,隐藏层太大又容易过拟合。这决定了RNN很难处理长序列问题。
我们发现,不是每一个观察值(序列值)都是同等重要的。
GRU的核心是说,神经网络可以有选择性的在序列中关注重要信息,忽略非重要信息。
所以他设置了两个门控电路:
- 更新门(update):就是说他觉得这个观察值是很重要的,于是开门让他来更新state值
- 遗忘门/重置门(reset):神经网络判断前面的序列是否还是重要的,如果不重要,可以初始化state,遗忘掉前面的信息。
(1)门
我们可以将门理解为和隐藏状态一样长的一个向量。激活函数使用sigmoid。计算方式如图所示:
![](https://img.haomeiwen.com/i12824314/8dc95d5c4e22d4d4.png)
(2)候选隐藏状态
公式中对了一条线路就是由于两个矩阵的大小是一致的,这里做的是对应元素的乘法。
属于(0,1),所以两者的乘积在如果R中的元素较小的话,与H中的元素相乘,就会使得结果接近于零,从而在后续的计算中失去作用。这就相当于神经网络把这个参数给忘记了。这就是为什么
被叫做遗忘门的原因,他的概念就像是电路设计里的门电路,与门。
![](https://img.haomeiwen.com/i12824314/e4fe75b600d633e6.png)
(3)隐状态
这里就是类似一个加权平均,拿一部分过去的信息,拿一部分更新后的候选码的信息共同作用生成一个结果。
![](https://img.haomeiwen.com/i12824314/235f3be1c859b498.png)
![](https://img.haomeiwen.com/i12824314/20b30342da3fc4c7.png)
(二)代码实现
import torch
from torch import nn
from d2l import torch as d2l
batch_size, num_step = 32,35
train_iter, vocab = d2l.load_data_time_machine(batch_size,num_step)
# 初始化模型参数
def get_params(vocab_size, num_hiddens, device):
num_inputs = num_outputs = vocab_size
def normal(shape):
return torch.randn(shape, device=device) * 0.01
def three():
return (
normal((num_inputs, num_hiddens)),
normal((num_hiddens, num_hiddens)),
torch.zeros(num_hiddens,device=device)
)
# 门电路的参数
W_xz, W_hz, b_z = three() # 更新门参数
W_xr, W_hr, b_r = three() # 重置门参数
W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数
W_hq = normal((num_hiddens, num_outputs))
b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
for param in params:
param.requires_grad_(True)
return params
# 初始化隐藏状态
def init_gru_state(batch_size, num_hiddens, device):
return (torch.zeros((batch_size, num_hiddens),device=device),)
def gru(inputs, state, params):
W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
H, = state
outputs = []
for X in inputs:
Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
H = Z * H + (1 - Z) * H_tilda
Y = H @ W_hq + b_q
outputs.append(Y)
return torch.cat(outputs, dim=0), (H,)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
![](https://img.haomeiwen.com/i12824314/3baa2f105b4e45fa.png)
网友评论