import torch
from torch import nn
from torch.utils import data
from torch.nn import functional as F
"""
类似于transformer 的add & norm 残差连接 当层数深的时候更容易优化
"""
class Highway(nn.Module):
    """Stack of highway layers.

    Each layer mixes a non-linear branch and a linear branch of its input
    through a learned, element-wise sigmoid gate:

        y = T(x) * f(G(x)) + (1 - T(x)) * Q(x)

    where ``G``, ``Q`` and the pre-activation of ``T`` are separate
    ``nn.Linear(size, size)`` maps, ``T(x) = sigmoid(Linear(x))`` and ``f``
    is the activation supplied by the caller.
    """

    def __init__(self, size, num_layers, f):
        """Build ``num_layers`` highway layers of width ``size``.

        :param size: input and output feature size of every affine map
        :param num_layers: number of highway layers applied in sequence
        :param f: callable applied to the output of the non-linear branch
            (e.g. ``torch.relu``)
        """
        super().__init__()
        self.num_layers = num_layers
        # Three independent affine maps per layer; ModuleList registers
        # their parameters with this module.
        self.nonlinear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.linear = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.gate = nn.ModuleList([nn.Linear(size, size) for _ in range(num_layers)])
        self.f = f

    def forward(self, x):
        """Apply every highway layer in order.

        :param x: tensor with shape of [batch_size, size]
        :return: tensor with shape of [batch_size, size]; when the module has
            no layers, ``x`` itself is returned unchanged
        """
        for gate_fc, nonlinear_fc, linear_fc in zip(self.gate, self.nonlinear, self.linear):
            # torch.sigmoid replaces the deprecated F.sigmoid alias.
            gate = torch.sigmoid(gate_fc(x))
            nonlinear = self.f(nonlinear_fc(x))
            linear = linear_fc(x)
            # Element-wise blend: gate -> 1 favours the non-linear branch,
            # gate -> 0 favours the linear branch.
            x = gate * nonlinear + (1 - gate) * linear
        return x
# (Scraped web-page residue: "reader comments" heading; not part of the module.)