美文网首页
通俗理解自注意力(self-attention)

通俗理解自注意力(self-attention)

作者: Black先森 | 来源:发表于2020-06-08 10:37 被阅读0次

    谷歌在2017年发表了一篇论文《Attention Is All You Need》,论文中提出了transformer模型,其核心就是self-attention的架构,这一突破性成果不仅洗遍了NLP的任务,也在CV中取得了非常好的效果,有大道至简的感觉。本文通过一个通俗易懂的例子[1]来介绍self-attention。

    文章首发个人博客
    (注:本文例子完全来在参考文章,包括文章的gif动图,感谢作者的文章)

    介绍

    接下来将通过一下几个步骤来介绍:

    1. 预处理输入数据
    2. 初始化权重
    3. 计算key,query 和value
    4. 计算输入值的注意力得分
    5. 计算softmax层
    6. 注意力得分与value相乘
    7. 对6中结果加权求和,并得到第一个输出值
    8. 重复4-7,计算其余输入数据的输出值

    预处理输入数据


    本例中我们选择三个输入值,已经通过embedding处理,得到了三个词向量。

    Input 1: [1, 0, 1, 0] 
    Input 2: [0, 2, 0, 2]
    Input 3: [1, 1, 1, 1]
    

    初始化权重

    权重包括三个,分别是query的W_q,key的W_k以及value的W_v,例如这三个权重分别初始化为

    W_k矩阵为:

    [[0, 0, 1],
     [1, 1, 0],
     [0, 1, 0],
     [1, 1, 0]]
    

    W_q矩阵为:

    [[1, 0, 1],
     [1, 0, 0],
     [0, 0, 1],
     [0, 1, 1]]
    

    W_v矩阵为:

    [[0, 2, 0],
     [0, 3, 0],
     [1, 0, 3],
     [1, 1, 0]]
    

    计算key,query 和value

    有了输入和权重,接下来可以计算每个输入对应的key,query 和value了。

    第一个输入的Key为:

                   [0, 0, 1]
    [1, 0, 1, 0] x [1, 1, 0] = [0, 1, 1]
                   [0, 1, 0]
                   [1, 1, 0]
    

    第二个输入的Key为:

                   [0, 0, 1]
    [0, 2, 0, 2] x [1, 1, 0] = [4, 4, 0]
                   [0, 1, 0]
                   [1, 1, 0]
    

    第三个输入的Key为:

                   [0, 0, 1]
    [1, 1, 1, 1] x [1, 1, 0] = [2, 3, 1]
                   [0, 1, 0]
                   [1, 1, 0]
    

    用矩阵的乘法来计算输入的Key为:

                   [0, 0, 1]
    [1, 0, 1, 0]   [1, 1, 0]   [0, 1, 1]
    [0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
    [1, 1, 1, 1]   [1, 1, 0]   [2, 3, 1]
    

    同理我们计算value的结果为:

                   [0, 2, 0]
    [1, 0, 1, 0]   [0, 3, 0]   [1, 2, 3] 
    [0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
    [1, 1, 1, 1]   [1, 1, 0]   [2, 6, 3]
    

    最后我们计算query的结果:

                   [1, 0, 1]
    [1, 0, 1, 0]   [1, 0, 0]   [1, 0, 2]
    [0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
    [1, 1, 1, 1]   [0, 1, 1]   [2, 1, 3]
    

    计算输入值的注意力得分

    注意力的得分是通过query与每个key结果相乘。例如对于第一个query(红色)分别与三个key(橙色)相乘,得到结果(蓝色)就是注意力得分。



    计算结果为:

                [0, 4, 2]
    [1, 0, 2] x [1, 4, 3] = [2, 4, 4]
                [1, 0, 1]
    

    计算softmax层

    softmax函数直接对上一步中的注意力得分做归一化处理。

    softmax([2, 4, 4]) = [0.0, 0.5, 0.5]
    

    得分与value相乘

    得到的每个得分值与自身的value直接相乘

    1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]
    2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]
    3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]
    

    对6中结果求和,并得到第一个输出值

    上一步骤中输出结果求和就得到第一个输出值

      [0.0, 0.0, 0.0]
    + [1.0, 4.0, 0.0]
    + [1.0, 3.0, 1.5]
    -----------------
    = [2.0, 7.0, 1.5]
    

    重复4-7,计算其余输入数据的输出值

    重复计算4-7,分别得到第二个和第三个输出值



    于是三个输入经过self-attention模块,得到了三个输出值。这就是attention模块做的事情,是不是很简单。《Attention Is All You Need》论文中的attention计算公式:

    Attention(Q,K, V) = softmax(\frac{QK^T}{\sqrt d_k})V

    attention最厉害的地方在于能够捕捉到全局信息,经过这个模块的输出结果,是通过输入结果两两运算得出了权重,再对输入进行加权求和得到了。除了捕捉全局信息,还能并行计算,这就比之前的RNN和CNN厉害多了,怪不得谷歌给这篇论文起名叫做Attention Is All You Need,有这个attention就够了。

    参考


    1. https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a

    相关文章

      网友评论

          本文标题:通俗理解自注意力(self-attention)

          本文链接:https://www.haomeiwen.com/subject/nebpohtx.html