美文网首页
门控制循环单元GRU

门控制循环单元GRU

作者: 小黄不头秃 | 来源:发表于2022-10-01 01:31 被阅读0次

(一)门控制循环单元GRU

这是一个在LSTM 之后发出的神经网络,他和LSTM模型性能差不多,但是结构要更加简单。

我们在实现RNN的时候我们时使用state这个变量来存储隐藏信息,然而这个隐藏信息太小不能很好的掌握序列的隐藏信息,隐藏层太大又容易过拟合。这决定了RNN很难处理长序列问题。

我们发现,不是每一个观察值(序列值)都是同等重要的。

GRU的核心是说,神经网络可以有选择性的在序列中关注重要信息,忽略非重要信息。
所以他设置了两个门控电路:

  • 更新门(update):就是说他觉得这个观察值是很重要的,于是开门让他来更新state值
  • 遗忘门/重置门(reset):神经网络判断前面的序列是否还是重要的,如果不重要,可以初始化state,遗忘掉前面的信息。
(1)门

我们可以将门理解为和隐藏状态一样长的一个向量。激活函数使用sigmoid。计算方式如图所示:

(2)候选隐藏状态

公式中对了一条线路就是R_t * H_{t-1}由于两个矩阵的大小是一致的,这里做的是对应元素的乘法。R_t属于(0,1),所以两者的乘积在如果R中的元素较小的话,与H中的元素相乘,就会使得结果接近于零,从而在后续的计算中失去作用。这就相当于神经网络把这个参数给忘记了。这就是为什么R_t被叫做遗忘门的原因,他的概念就像是电路设计里的门电路,与门。

(3)隐状态

这里就是类似一个加权平均,拿一部分过去的信息,拿一部分更新后的候选码的信息共同作用生成一个结果。

(二)代码实现

import torch 
from torch import nn 
from d2l import torch as d2l

batch_size, num_step = 32,35
train_iter, vocab = d2l.load_data_time_machine(batch_size,num_step)
# 初始化模型参数
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(shape, device=device) * 0.01
    
    def three():
        return (
            normal((num_inputs, num_hiddens)),
            normal((num_hiddens, num_hiddens)),
            torch.zeros(num_hiddens,device=device)
        )

    # 门电路的参数
    W_xz, W_hz, b_z = three()  # 更新门参数
    W_xr, W_hr, b_r = three()  # 重置门参数
    W_xh, W_hh, b_h = three()  # 候选隐状态参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
# 初始化隐藏状态
def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens),device=device),)
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

相关文章

网友评论

      本文标题:门控制循环单元GRU

      本文链接:https://www.haomeiwen.com/subject/lpheortx.html