美文网首页PaperNote
【Papernote】ATTENTION IS ALL MY N

【Papernote】ATTENTION IS ALL MY N

作者: Chiduru | 来源:发表于2020-10-27 01:46 被阅读0次


    【Mainly Contribution】

        attention;

        mask machines;

        position embedding;

    Attention machines

    attention部分主要是scaled dot-product attention和feed forward两部分,公式如下:

    INPUT: Q,K,V

    其中,QKt计算得到一个weight score(为稳定梯度做一次归一 --> /根号dk)

    乘积矩阵作为一个weight matrix再与V做一次乘法, 过softmax, 完成一次权重重分配(是的

    其中,self-attention是Q,K,V全都有一个encode input tensor *不同的Wq,Wk, Wv得到的不同tensor,大白话就是自己乘自己

    而在transformer中是decoder层中, encoder-decoder attention是Query=dec_input,K & V=enc_output的输入,即encoder的输出跟deocder的输入做了交叉

    mask部分后面再讲。

    ScaledDotProductAttention

    (self-attention)input_vector 分别×Wq, Wk, Wv得到三个不同的tensor: Vq, Vk, Vv, then↓

    ```

    def forward(self, q, k, v, mask=None):

        attn = torch.bmm(q, k.transpose(1, 2))

        attn = attn / self.temperature

        if mask is not None:

            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)

        attn = self.dropout(attn)

        output = torch.bmm(attn, v)

        return output, attn

    ```

    multi-head

    分为n_head头并行做计算后concat。个人理解源码里在一个矩阵内同时计算,permute & view, 没有拆开再做concat

    ```

    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        sz_b, len_q, _ = q.size()

        sz_b, len_k, _ = k.size()

        sz_b, len_v, _ = v.size()

        residual = q

        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)   

        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)

        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)    # (n*b) x lq x dk

        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)    # (n*b) x lk x dk

        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)    # (n*b) x lv x dv

        mask = mask.repeat(n_head, 1, 1)

        output, attn = self.attention(q, k, v, mask=mask)

        output = output.view(n_head, sz_b, len_q, d_v)

        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)      # b x lq x (n*dv)

        output = self.dropout(self.fc(output))

        output = self.layer_norm(output + residual)

        return output, attn

    ```

    feed forward

    ```

    def forward(self, x):

        residual = x

        output = x.transpose(1, 2)

        output = self.w_2(F.relu(self.w_1(output)))

        output = output.transpose(1, 2)

        output = self.dropout(output)

        output = self.layer_norm(output + residual)

        return output

    ```

    Mask machines

    这里的mask主要分为padding和sequence两部分:

    ① padding比较好理解, 预处理不等长数据时padding的元素位在计算时没有意义, 因此需要每次做一次padding位置零;

    ② sequence mask主要是用在encode-decode中的decoder部分,decoder在计算attention时也有两部分,一个是decoder input自身的attention计算, 一部分是encode&decode部分的attention,这里q=decode_input, key和value是encode部分的输出结果来裹一层attention。然后结合enc-output和decode input来开始逐步生成长度=n的decoder output, 对于每个i=1,2,3...n,i+1及之后的数据都是当前training的答案, 所以需要将i之后的元素(单词)置零防止标签泄露, 通常使用一个上三角矩阵。

    padding mask

    sequence mask

    position embedding

    把位置元素按照奇偶做一个sin or cos转换,加到word_tensor中, 作为顺序维度的特征embedding。

    公式:


    历史

    相关文章

      网友评论

        本文标题:【Papernote】ATTENTION IS ALL MY N

        本文链接:https://www.haomeiwen.com/subject/ejltvktx.html