HighWay Net: Principles and Implementation

Author: top_小酱油 | Published: 2020-04-29 11:10
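import torch
from torch import nn

"""
Similar in spirit to the residual connection in the Transformer's "Add & Norm":
a skip path lets information and gradients flow through many layers, which makes
deep networks easier to optimize. Unlike a plain residual connection, a Highway
layer uses a learned gate to decide, element by element, how much of the
transformed signal and how much of the carried signal to pass on.
"""


class Highway(nn.Module):
    def __init__(self, size, num_layers, f):
        """
        :param size: feature dimension of both input and output
        :param num_layers: number of stacked highway layers
        :param f: non-linear activation for the transform path, e.g. torch.relu
        """
        super().__init__()
        self.num_layers = num_layers
        # Per-layer affine maps: G (transform path), Q (carry path) and the gate.
        self.nonlinear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.linear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.f = f

    def forward(self, x):
        """
        :param x: tensor with shape [batch_size, size]
        :return: tensor with shape [batch_size, size]

        Applies σ(x) ⨀ f(G(x)) + (1 - σ(x)) ⨀ Q(x), where G and Q are affine
        transformations, f is a non-linear transformation, σ(x) is an affine
        transformation followed by a sigmoid, and ⨀ is element-wise multiplication.

        Note: the original Highway Network (Srivastava et al., 2015) carries the
        input unchanged, i.e. (1 - σ(x)) ⨀ x; this version uses a learned affine
        carry Q(x) instead.
        """
        for layer in range(self.num_layers):
            # Transform gate, element-wise in (0, 1). torch.sigmoid replaces the deprecated F.sigmoid.
            gate = torch.sigmoid(self.gate[layer](x))
            nonlinear = self.f(self.nonlinear[layer](x))
            linear = self.linear[layer](x)
            # The gate blends the transformed signal with the carried signal.
            x = gate * nonlinear + (1 - gate) * linear
        return x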
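For comparison with the class above, here is a minimal sketch of the original Highway formulation, in which the carry path is the identity: y = T(x) ⨀ H(x) + (1 - T(x)) ⨀ x. The class name `IdentityHighway` and the `gate_bias` parameter are hypothetical names chosen for illustration. The Highway Networks paper recommends initializing the transform-gate bias to a negative value so that each layer starts out close to passing its input through unchanged. The default of -2.0 used here is an assumption within that range, not a tuned setting.

```python
import torch
from torch import nn


class IdentityHighway(nn.Module):
    """Hypothetical sketch: highway layers with an identity carry path,
    y = T(x) * H(x) + (1 - T(x)) * x (Srivastava et al., 2015)."""

    def __init__(self, size, num_layers, f=torch.relu, gate_bias=-2.0):
        super().__init__()
        self.transform = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.f = f
        # Assumed initialization: a negative gate bias keeps T(x) small at the start,
        # so each layer initially behaves almost like an identity map.
        for g in self.gate:
            nn.init.constant_(g.bias, gate_bias)

    def forward(self, x):
        for transform, gate in zip(self.transform, self.gate):
            t = torch.sigmoid(gate(x))   # transform gate T(x), element-wise in (0, 1)
            h = self.f(transform(x))     # transformed signal H(x)
            x = t * h + (1 - t) * x      # carry gate is C(x) = 1 - T(x); carry is x itself
        return x


torch.manual_seed(0)
x = torch.randn(4, 16)
model = IdentityHighway(size=16, num_layers=3)
y = model(x)
print(y.shape)  # torch.Size([4, 16])
```

Because the identity carry requires input and output to share the same shape, every layer here is `size -> size`, just as in the class above. The learned affine carry Q(x) in the original code relaxes this only in principle; as written it also keeps `size` fixed.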

    Article link: https://www.haomeiwen.com/subject/dclgwhtx.html