美文网首页
Multi-head attention 多头注意力机制

Multi-head attention 多头注意力机制

作者: 锦绣拾年 | 来源:发表于2021-06-20 22:07 被阅读0次

    Multi-head attention

    本文基于《dive into deep learning》-pytorch

    代码参考 《dive into deep learning》-pytorch

    multi-head attention

    基本信息

    我们可以会希望注意力机制可以联合使用不同子空间的key,value,query的表示。因此,不是只用一个attention pooling,query、key、value可以被h个独立学到的线性映射转换。最后,h个attention pooling输出concat 并且再次通过一个线性映射得到最后的输出。

    这种设计就是multi-head attention, h个attention pooling输出中的每一个就是一个头。使用全连接层来实现线性转换。

    multi-attention1.png

    理解纠错

    【我过去有一个误解,就是multi-head是和CNN类似的机制,用多个的W降维,之后再计算多个注意力分数,再concat。直到我用pytorch中自带的multi-head attention,要求num_heads是hidden层维度可以整除的数,才发现这里的multi-head是针对子空间的】
    【但是这里可以理解,用同样的维度,训练多个空间,可以更好地增强表达能力】

    这部分解答可以参考:

    https://www.zhihu.com/question/350369171 -transformer中multi-head attention中每个head为什么要进行降维?(实际上用切割来表示更为准确)

    https://www.zhihu.com/question/446385446 - BERT中,multi-head 768*64*12与直接使用768*768矩阵统一计算,有什么区别?

    对于 Multi-Head Attention,简单来说就是多个 Self-Attention 的组合,但多头的实现不是循环的计算每个头,而是通过 transposes and reshapes,用矩阵乘法来完成的。

    In practice, the multi-headed attention are done with transposes and reshapes rather than actual separate tensors. —— 来自 google BERT 源代码注释

    Transformer中把 d ,也就是hidden_size/embedding_size这个维度做了reshape拆分,具体可以看对应的 pytorch 代码

    hidden_size (d) = num_attention_heads (m) * attention_head_size (a),也即 d=m*a

    【↑作者:海晨威
    链接:https://www.zhihu.com/question/350369171/answer/1718672303
    来源:知乎
    著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。】

    正如回答中有写:

    transformer中multi-head attention中每个head为什么要进行降维? - LooperXX的回答 - 知乎 https://www.zhihu.com/question/350369171/answer/860552006

    回到题主的问题上来,如果只使用 one head 并且维度为 d_model ,相较于 8 head 并且维度为d_model/8,存在高维空间下学习难度较大的问题,文中实验也证实了这一点,于是将原有的高维空间转化为多个低维子空间并再最后进行拼接,取得了更好的效果,十分巧妙。

    在实现的时候,multi-head把维度从[batch, len, embeding]变为[batch, len, head, embeding/head], 然后head就是多头,对每一个 embeding/head部分计算对应的attention。

    pytorch实现

    class MultiHeadAttention(nn.Module):
        def __init__(self, key_size, query_size, value_size, num_hiddens,
    num_heads, dropout, bias=False, **kwargs):
            super(MultiHeadAttention, self).__init__(**kwargs)
            self.num_heads = num_heads
            self.attention = DotProductAttention(dropout)
            self.W_q = nn.Linear(query_size,num_hiddens,bias=bias)
            self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
            self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)#映射到numhiddens
            self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        def forward(self,queries,keys,values,valid_lens):
            #注意最后的 [batch_size` * `num_heads`,number of  key-value pairs,num_hiddens` / `num_heads]
            #这里涉及到reshape操作
            queries = transpose_qkv(self.W_q(queries),self.num_heads)#batch,seq,embed -> batch*num_head,seq,embed/num_head
            keys = transpose_qkv(self.W_k(keys),self.num_heads)
            values = transpose_qkv(self.W_v(values), self.num_heads)
            if valid_lens is not None:#相当于每个batch扩充num_heads遍
                valid_lens = torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0)
            print(queries.shape)#10,4,20
            print(values.shape)#10,6,20
            print(")*&^%$^&*()")
            output = self.attention(queries, keys, values, valid_lens)#attention计算是transpose之后的向量
            #得到,batch×head, seq,embed/head的矩阵,每一个embed/head是这一部分词向量子空间的attention加权和。
            weights= self.attention.attention_weights
            print(weights.shape)#10,4,6  query: 2 4 100 key: 2,6,100 ,一共10组,每组 4×6,query和key的交互值
            
            output_concat = transpose_output(output, self.num_heads)#transpose的逆运算
            return self.W_o(output_concat)#最后做一次线性变换 #2,4,100
            
    
    def transpose_qkv(X, num_heads):
        X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)#batch seq head embed/head
    
        X = X.permute(0, 2, 1, 3) # batch head seq embed/head
    
        return X.reshape(-1, X.shape[2], X.shape[3])# batch×head, seq,embed/head
    
    def transpose_output(X, num_heads):# batch×head, seq,embed/head
        """Reverse the operation of `transpose_qkv`"""
        X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])# batch,head, seq,embed/head
        X = X.permute(0, 2, 1, 3)## batch,seq,head, embed/head
        return X.reshape(X.shape[0], X.shape[1], -1)#batch,seq,embed
    
    num_hiddens, num_heads = 100, 5
    attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
    attention.eval()
    batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
    X = torch.ones((batch_size, num_queries, num_hiddens))#2,4,100
    Y = torch.ones((batch_size, num_kvpairs, num_hiddens))#2,6,100
    attention(X, Y, Y, valid_lens).shape #2,4,100 query有4个,得到4个对应的结果
    #中间 attention weight大小是 10,4,6
    

    self-attention

    输入和输出大小一样

    query,key,value一样

    batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
    X = torch.ones((batch_size, num_queries, num_hiddens))#2,4,100
    attention(X, X, X, valid_lens).shape #2,4,100
    
    • 自注意力机制中,query,key,value来自于相同的空间
    • CNN和self-attention都有利于并行运算,self-attention有要求最短的最大路径长度。但是由于复杂度是序列长度的平方,长序列会计算比较慢。
    • 为了使用序列顺序信息,我们可以通过向输入表示添加位置编码注入绝对位置或相对位置信息,如transformer的 PositionalEncoding

    补充
    Transformer/CNN/RNN的对比(时间复杂度,序列操作数,最大路径长度) - Gordon Lee的文章 - 知乎
    https://zhuanlan.zhihu.com/p/264749298
    https://spaces.ac.cn/archives/4765
    ↑对self-attention的分析也很好,self-attention有不能充分编入位置信息的硬伤等

    相关文章

      网友评论

          本文标题:Multi-head attention 多头注意力机制

          本文链接:https://www.haomeiwen.com/subject/uxfeyltx.html