TL;DR
- paper中的位置编码定义可以直观理解为 "钟表盘上每个针头的位置坐标"
- 跟直接拿index作为位置编码的方案相比,这种定义有两个优点
- 可以使用不含bias的线性变换来表征,从而便于模型attend到相对位置[1]
- 各维度的周期互相交错,表征能力为其最小公倍数;能对训练数据中从没见过的更长位置信息加以编码
问题由来
深度学习中有个著名的 Transformer模型 (2017年Google Brain发表那篇《Attention Is All You Need》[2]) ,其中有个设计得奇形怪状的 “Position Encoding” 一直不太好理解
定义成下面这个样子(一脸懵逼,有木有……)
关于这个 Position-Encoding的定义,有以下几点疑惑:
- 这种复杂的定义,真的能编码 位置信息吗?
- 为啥不直接用下标,而是整个这么复杂的定义?
巧的是,我昨晚看《长安十二时辰》的时候,联想到“天干地支”,
似乎对于这个定义有了 比较直观的理解 ;且待我与列位看官慢慢道来。
真能编码位置信息
干支纪年
不妨回忆一下老祖宗的“天干地支”纪年法(类似两个齿轮):
- 以"十天干"(“甲、乙、丙、丁、戊、己、庚、辛、壬、癸”)为 第一维度
- 以"十二地支"(“子、丑、寅、卯、辰、巳、午、未、申、酉、戌、亥”)为 第二维度
- 两个维度 绑定在一起 循环滚动,但是周期不同(分别为10/12)
- e.g. (简书的markdown-latex在移动端不支持中文字符,只好截图如下)
-
两维度若独立滚动,排列组合就是 ,与常识冲突
- 表征能力是 两个周期的最小公倍数
(图片来自网络 http://oraclebonescriptdiary.blogspot.com/2013/08/blog-post.html)
钟表计时
类似的,看看机械钟表的表盘
- 这是机械钟而非电子钟,可以认为 时针与分钟的行程都是连续(而非离散)的
- 忽略纪年和计时的区别;时针与分针 类似与 天干和地支
- 时针与分针 滚动速度(频率)不同,前者快60倍
- 频率是周期的倒数,也可认为两者是 绑定在一起 滚动,但是周期不同
- 由于时针的周期恰好能被分针整除,故该钟的表征能力 等于一根时针
(图片来自京东 https://item.jd.com/54349670287.html)
(注:图片仅用于研究目的,不带货哈)
单位圆上的点坐标表征与旋转
回顾一下高中数学:
- 单位圆上的任意点坐标可以表达为 的形式
- 不同的周期类比于不同的表针 (时、分、秒 ...) / 或者理解为 天干和地支
- 任意一根表针的 针头坐标旋转角,皆可用 一个 的 仅与 旋转角度有关、而与起点位置无关的矩阵表达
- 多根表针的情况类似,用更大的矩阵可以表达
因此,时间"10:55"也可以(冗余)表达为
Transformer模型中的位置编码
再看看 Transformer模型中的 Position-Encoding 定义:
其中
- 代表序列内维度(第几帧)
- / 分别代表PE的奇数/偶数维度(位置编码向量的第几维)
- 从上述矩阵中切片某一列(viz. 固定列坐标,只看PE的某一个维度),并将简写为,得到下述列向量
恰好就是在描述 根针构成的表盘上,各针头的坐标
显然,每个针头的坐标都清楚了,自然有能力表征位置信息(甚至有点维度冗余)
整这么复杂的定义,有道理
上文说过,表针的旋转可以使用不含bias的矩阵来表达;复述如下:
- 任意一根表针的 针头坐标旋转角,皆可用 一个 的 仅与 旋转角度有关、而与起点位置无关的矩阵表达
- 多根表针的情况类似,用更大的矩阵可以表达
矩阵乘法本质是线性变换,对应于Dense
层; 而 帧的位置向量可以表示为 帧 的位置向量的线性变换(只需旋转,无需偏置 ; 详见证明[3]).
这为模型捕捉单词之间的相对位置关系提供了便利,使得模型能够 方便地attend到相对时刻
思考 self-attention-layer 中 Q/K/V 的定义
- 每一帧都是先做word_embedding("我是啥语义"),然后加上positional_embedding("我在哪一帧"),然后使用 矩阵做线性变换,得到代表该帧作为key的向量. (Q/V 类似同理)
- 解码到帧时,对帧的attent程度为 ; 是点积形式, 两向量取值越接近,点积越大
-
和 分别都是
dense(word_embedding+positional_embedding)
的形式;若能在dense中学到rotate(-3)
的关系,即可使两个向量非常接近,点积很大
反之,如果直接使用下标作为位置的定义,
则"相对时间"的概念要通过 含有bias的"仿射变换(viz. 线性变换+平移)"才能表达
另外,如上文所述的“干支纪年法”类似;各个维度的周期不能整除,可以表征的范式是各维度的最小公倍数。
这样就能在inference时,对训练数据中从没见过的更长位置信息加以编码
下面给出 MXNet中位置编码的教学实现[4];工程实现类似,参见 GitHub
class PositionalEncoding(nn.Block):
def __init__(self, units, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
T = nd.arange(0, max_len).reshape((-1,1)) / nd.power(
10000, nd.arange(0, units, 2)/units) # 临时矩阵T
self.P = nd.zeros((1, max_len, units)) # 注意P是常数矩阵,无需训练
self.P[:, :, 0::2] = nd.sin(T) # 偶数下标行 取sin
self.P[:, :, 1::2] = nd.cos(T)
self.dropout = nn.Dropout(dropout)
def forward(self, X):
# P是常数矩阵,直接截取需要的形状部分 然后与X相加即可
X = X + self.P[:, :X.shape[1], :].as_in_context(X.context)
return self.dropout(X)
尚存几点疑惑
- 位置编码为啥要跟语义向量
add
到一起,而非concat
;上文中的旋转矩阵作用到语义部分,担心会有副作用 - 相对位置可以表达为线性变换,凭啥就有利于模型学习? paper原文中语焉不详,上述 K/Q/V的解释只是我的一种猜想,也不能很严谨的说服自己
- 更多欢迎留言讨论
Reference
-
"详解Transformer (Attention Is All You Need)", 刘岩, 2018, https://zhuanlan.zhihu.com/p/48508221 ↩
-
Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems. 2017. https://arxiv.org/abs/1706.03762 ↩
-
"Linear Relationships in the Transformer’s Positional Encoding" Timo Denk's Blog, Timo Denk, 2019, https://timodenk.com/blog/linear-relationships-in-the-transformers-positional-encoding/ ↩
-
"9.3 Transformer" Dive into Deep Learning, Mu Li, et al, http://en.d2l.ai/chapter_attention-mechanism/transformer.html ↩
网友评论