《十二时辰》教你直观理解 Position-Encoding

作者: huge823 | 来源:发表于2019-08-25 12:00 被阅读0次

    TL;DR

    • paper中的位置编码定义可以直观理解为 "钟表盘上每个针头的位置坐标"
    • 跟直接拿index作为位置编码的方案相比,这种定义有两个优点
      • 可以使用不含bias的线性变换来表征\Delta t,从而便于模型attend到相对位置[1]
      • 各维度的周期互相交错,表征能力为其最小公倍数;能对训练数据中从没见过的更长位置信息加以编码

    问题由来

    深度学习中有个著名的 Transformer模型 (2017年Google Brain发表那篇《Attention Is All You Need》[2]) ,其中有个设计得奇形怪状的 “Position Encoding” 一直不太好理解

    定义成下面这个样子(一脸懵逼,有木有……)

    \begin{aligned} PE_{(pos,2i)} &= sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos,2i+1)} &= cos(pos/10000^{2i/d_{model}}) \end{aligned}

    关于这个 Position-Encoding的定义,有以下几点疑惑:

    1. 这种复杂的定义,真的能编码 位置信息吗?
    2. 为啥不直接用下标,而是整个这么复杂的定义?

    巧的是,我昨晚看《长安十二时辰》的时候,联想到“天干地支”,
    似乎对于这个定义有了 比较直观的理解 ;且待我与列位看官慢慢道来。

    真能编码位置信息

    干支纪年

    不妨回忆一下老祖宗的“天干地支”纪年法(类似两个齿轮):

    • 以"十天干"(“甲、乙、丙、丁、戊、己、庚、辛、壬、癸”)为 第一维度
    • 以"十二地支"(“子、丑、寅、卯、辰、巳、午、未、申、酉、戌、亥”)为 第二维度
    • 两个维度 绑定在一起 循环滚动,但是周期不同(分别为10/12)
      • e.g. (简书的markdown-latex在移动端不支持中文字符,只好截图如下)
      • 两维度若独立滚动,排列组合就是 10\times12=120 \neq 60 ,与常识冲突

    • 表征能力是 两个周期的最小公倍数 lcm(10,12)=60
    天干地支
    (图片来自网络 http://oraclebonescriptdiary.blogspot.com/2013/08/blog-post.html)

    钟表计时

    类似的,看看机械钟表的表盘

    • 这是机械钟而非电子钟,可以认为 时针与分钟的行程都是连续(而非离散)的
    • 忽略纪年和计时的区别;时针与分针 类似与 天干和地支
      • 时针与分针 滚动速度(频率)不同,前者快60倍
      • 频率是周期的倒数,也可认为两者是 绑定在一起 滚动,但是周期不同
    • 由于时针的周期恰好能被分针整除,故该钟的表征能力 等于一根时针
    机械钟的盘面
    (图片来自京东 https://item.jd.com/54349670287.html)
    (注:图片仅用于研究目的,不带货哈)

    单位圆上的点坐标表征与旋转

    回顾一下高中数学:

    • 单位圆上的任意点坐标可以表达为 (x,y) = \left(\cos(\varphi), \sin(\varphi)\right) 的形式
    • 不同的周期类比于不同的表针 (时、分、秒 ...) / 或者理解为 天干和地支
    • 任意一根表针的 针头坐标(x,y)旋转\theta角,皆可用 一个 2\times 2的 仅与 旋转角度\theta有关、而与起点位置(x,y)无关的矩阵表达
      \left[\begin{array}{c}{x^{\prime}} \\ {y^{\prime}}\end{array}\right]=\left[\begin{array}{cc}{\cos \theta} & {-\sin \theta} \\ {\sin \theta} & {\cos \theta}\end{array}\right] *\left[\begin{array}{l}{x} \\ {y}\end{array}\right]
    • 多根表针的情况类似,用更大的矩阵可以表达

    因此,时间"10:55"也可以(冗余)表达为
    \begin{aligned} &\left(hour_x, hour_y, min_x, min_y \right) \\ = &\left( \cos\left(\frac{10}{12}\cdot 2\pi \right), \sin\left(\frac{10}{12}\cdot 2\pi \right), \cos\left(\frac{55}{60}\cdot 2\pi \right), \sin\left(\frac{55}{60}\cdot 2\pi \right) \right) \\ \triangleq &\left( \cos\left(\frac{10}{f_1}\right), \sin\left(\frac{10}{f_1}\right), \cos\left(\frac{55}{f_2}\right), \sin\left(\frac{55}{f_2}\right) \right) \\ \end{aligned}

    Transformer模型中的位置编码

    再看看 Transformer模型中的 Position-Encoding 定义:

    \begin{aligned} PE_{(pos,2i)} &= sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos,2i+1)} &= cos(pos/10000^{2i/d_{model}}) \end{aligned}

    其中

    • pos代表序列内维度(第几帧)
    • 2i/2i+1 分别代表PE的奇数/偶数维度(位置编码向量的第几维)
    • 从上述矩阵中切片某一列(viz. 固定列坐标,只看PE的某一个维度),并将pos简写为t,得到下述列向量
      \begin{bmatrix} \sin\left(\frac{t}{f_1}\right)\\ \cos\left(\frac{t}{f_1}\right)\\ \sin\left(\frac{t}{f_2}\right)\\ \cos\left(\frac{t}{f_2}\right)\\ \vdots\\ \sin\left(\frac{t}{f_{\frac{d_\text{model}}{2}}}\right)\\ \cos\left(\frac{t}{f_{\frac{d_\text{model}}{2}}}\right) \end{bmatrix}

    恰好就是在描述 d_{\text{model}}/2根针构成的表盘上,各针头的坐标
    显然,每个针头的坐标都清楚了,自然有能力表征位置信息(甚至有点维度冗余)

    整这么复杂的定义,有道理

    上文说过,表针的旋转可以使用不含bias的矩阵来表达;复述如下:

    • 任意一根表针的 针头坐标(x,y)旋转\theta角,皆可用 一个 2\times 2的 仅与 旋转角度\theta有关、而与起点位置(x,y)无关的矩阵表达
      \left[\begin{array}{c}{x^{\prime}} \\ {y^{\prime}}\end{array}\right]=\left[\begin{array}{cc}{\cos \theta} & {-\sin \theta} \\ {\sin \theta} & {\cos \theta}\end{array}\right] *\left[\begin{array}{l}{x} \\ {y}\end{array}\right]
    • 多根表针的情况类似,用更大的矩阵可以表达

    矩阵乘法本质是线性变换,对应于Dense层; 而 t_0 + \Delta t帧的位置向量可以表示为 t_0帧 的位置向量的线性变换(只需旋转,无需偏置 ; 详见证明[3]).
    这为模型捕捉单词之间的相对位置关系提供了便利,使得模型能够 方便地attend到相对时刻

    思考 self-attention-layer 中 Q/K/V 的定义

    • 每一帧都是先做word_embedding("我是啥语义"),然后加上positional_embedding("我在哪一帧"),然后使用 K 矩阵做线性变换,得到代表该帧作为key的向量. (Q/V 类似同理)
    • 解码到t帧时,对t-3帧的attent程度为 Q^T_{t} K_{t-3}; 是点积形式, 两向量取值越接近,点积越大
    • QK分别都是 dense(word_embedding+positional_embedding) 的形式;若能在dense中学到rotate(-3)的关系,即可使两个向量非常接近,点积很大

    反之,如果直接使用下标作为位置的定义,
    则"相对时间"的概念要通过 含有bias的"仿射变换(viz. 线性变换+平移)"才能表达

    另外,如上文所述的“干支纪年法”类似;各个维度的周期不能整除,可以表征的范式是各维度的最小公倍数。
    这样就能在inference时,对训练数据中从没见过的更长位置信息加以编码

    下面给出 MXNet中位置编码的教学实现[4];工程实现类似,参见 GitHub

    class PositionalEncoding(nn.Block):
        def __init__(self, units, dropout, max_len=1000):
            super(PositionalEncoding, self).__init__()
            T = nd.arange(0, max_len).reshape((-1,1)) / nd.power(
                10000, nd.arange(0, units, 2)/units) # 临时矩阵T
            self.P = nd.zeros((1, max_len, units)) # 注意P是常数矩阵,无需训练
            self.P[:, :, 0::2] = nd.sin(T) # 偶数下标行 取sin
            self.P[:, :, 1::2] = nd.cos(T)
            self.dropout = nn.Dropout(dropout)
    
        def forward(self, X):
            # P是常数矩阵,直接截取需要的形状部分 然后与X相加即可
            X = X + self.P[:, :X.shape[1], :].as_in_context(X.context)
            return self.dropout(X)
    

    尚存几点疑惑

    1. 位置编码为啥要跟语义向量add到一起,而非concat;上文中的旋转矩阵作用到语义部分,担心会有副作用
    2. 相对位置可以表达为线性变换,凭啥就有利于模型学习? paper原文中语焉不详,上述 K/Q/V的解释只是我的一种猜想,也不能很严谨的说服自己
    3. 更多欢迎留言讨论

    Reference


    1. "详解Transformer (Attention Is All You Need)", 刘岩, 2018, https://zhuanlan.zhihu.com/p/48508221

    2. Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems. 2017. https://arxiv.org/abs/1706.03762

    3. "Linear Relationships in the Transformer’s Positional Encoding" Timo Denk's Blog, Timo Denk, 2019, https://timodenk.com/blog/linear-relationships-in-the-transformers-positional-encoding/

    4. "9.3 Transformer" Dive into Deep Learning, Mu Li, et al, http://en.d2l.ai/chapter_attention-mechanism/transformer.html

    相关文章

      网友评论

        本文标题:《十二时辰》教你直观理解 Position-Encoding

        本文链接:https://www.haomeiwen.com/subject/fmaqectx.html