
Transformer Architectures in Low-Level Vision

Author: lishuoshi1996 | Published 2022-03-21 11:26

    1. Restormer: Efficient Transformer for High-Resolution Image Restoration

    (1) Overall architecture

    The overall structure is similar to U-Net. Part (a) of the figure shows the authors' multi-Dconv head transposed attention (MDTA); its key idea is to move the attention computation from the spatial dimension to the channel dimension, which makes high-resolution images tractable (see the shape sketch below). Part (b) is the authors' gated-Dconv feed-forward network (GDFN), a smaller, icing-on-the-cake improvement.

    Figure: the overall architecture of Restormer
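    To see why moving attention to the channel dimension helps, here is a minimal shape-only sketch (my own illustration, not code from the paper): for a (B, C, H, W) feature map, a spatial attention map costs O((HW)^2) memory, while the transposed (channel) attention map is only C x C, independent of resolution.

    import torch

    B, C, H, W = 1, 48, 64, 64
    x = torch.randn(B, C, H, W)
    tokens = x.flatten(2)                            # (B, C, H*W)

    # standard spatial self-attention: the map grows with image resolution
    spatial_attn = tokens.transpose(1, 2) @ tokens   # (B, H*W, H*W) -> 4096 x 4096

    # MDTA-style transposed attention: the map depends only on the channel count
    channel_attn = tokens @ tokens.transpose(1, 2)   # (B, C, C) -> 48 x 48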

    (2) Parameter settings

    The block counts L1-L4 in the figure are 4, 6, 6, 8, respectively; the numbers of attention heads are 1, 2, 4, 8; and the feature channel counts are 48, 96, 192, 384. L_r (the number of refinement blocks) is 4. The optimizer is AdamW, with the learning rate decayed from 3e-4 to 1e-6 using a cosine annealing schedule. In addition, a progressive learning strategy is used: as training proceeds, the image patch size is gradually increased while the batch size is decreased. Finally, horizontal and vertical flips are used for data augmentation.

    Overall, the training recipe relies on quite a few tricks. Judging from Table 7 of the paper, the model's FLOPs are relatively low, but its parameter count is fairly large.
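    As a rough sketch of the optimization setup described above (not the authors' released training code, and with placeholder values where noted), the AdamW-plus-cosine-annealing schedule might look like this in PyTorch:

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR

    model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in for the full Restormer model

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # decay the learning rate from 3e-4 down to 1e-6 over training;
    # T_max (total iterations) is an assumed placeholder
    scheduler = CosineAnnealingLR(optimizer, T_max=300_000, eta_min=1e-6)

    # illustrative progressive-learning stages as (patch size, batch size);
    # the exact schedule is a placeholder -- the idea is that patches grow
    # while batches shrink, keeping memory use roughly constant
    stages = [(128, 64), (160, 40), (192, 32), (256, 16), (320, 8), (384, 8)]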

    (3) Code implementation

    A. Multi-DConv Head Transposed Self-Attention (MDTA)

    import torch
    import torch.nn as nn
    from einops import rearrange  # pip install einops

    class Attention(nn.Module):
        def __init__(self, dim, num_heads, bias):
            super(Attention, self).__init__()
            self.num_heads = num_heads  # number of attention heads
            # learnable per-head temperature scaling the attention logits
            self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
            # 1x1 conv producing Q, K and V in a single pass
            self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
            # depthwise 3x3 conv: encodes local spatial context channel-by-channel
            self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
            self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

        def forward(self, x):
            b, c, h, w = x.shape
            qkv = self.qkv_dwconv(self.qkv(x))
            q, k, v = qkv.chunk(3, dim=1)

            # flatten the spatial dimensions so attention runs across channels
            q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

            q = torch.nn.functional.normalize(q, dim=-1)
            k = torch.nn.functional.normalize(k, dim=-1)

            # (c x hw) @ (hw x c) -> c x c attention map, independent of resolution
            attn = (q @ k.transpose(-2, -1)) * self.temperature
            attn = attn.softmax(dim=-1)

            out = (attn @ v)
            out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
            out = self.project_out(out)
            return out
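    A quick shape check for the module above (dimensions chosen for illustration; they match the first encoder level's settings from section (2)):

    attn = Attention(dim=48, num_heads=1, bias=False)
    x = torch.randn(1, 48, 64, 64)
    print(attn(x).shape)  # torch.Size([1, 48, 64, 64])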

    B. Gated-Dconv Feed-Forward Network (GDFN)

    import torch.nn as nn
    import torch.nn.functional as F

    class FeedForward(nn.Module):
        def __init__(self, dim, ffn_expansion_factor, bias):
            super(FeedForward, self).__init__()
            hidden_features = int(dim*ffn_expansion_factor)
            # 1x1 conv expanding the channels; doubled for the two gating branches
            self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
            # depthwise 3x3 conv applied to both branches at once
            self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
            self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

        def forward(self, x):
            x = self.project_in(x)
            # split into two branches; the GELU-activated branch gates the other
            x1, x2 = self.dwconv(x).chunk(2, dim=1)
            x = F.gelu(x1) * x2
            x = self.project_out(x)
            return x
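    And a matching shape check for the GDFN (the expansion factor 2.66 is the default in the official Restormer code; the other values are illustrative):

    ffn = FeedForward(dim=48, ffn_expansion_factor=2.66, bias=False)
    x = torch.randn(1, 48, 64, 64)
    print(ffn(x).shape)  # torch.Size([1, 48, 64, 64])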
