Klein T, Nabi M. miCSE: Mutual Information Contrastive Learning for Low-shot Sentence Embeddings[J]. arXiv preprint arXiv:2211.04928, 2022.
摘要导读
本文提出了一个基于互信息的对比学习框架miCSE,显著地提高了在少量句子嵌入方面的先进水平。该方法在对比学习过程中调整了不同视图之间的注意力模式。通过miCSE学习句子嵌入需要加强每个句子的增强视图的结构一致性,使对比自监督学习提高了效率。因此,在小样本领域该方法取得了很好的效果,在全样本的场景下依旧适用。
本文的贡献如下:
- 通过添加一个attention-level的目标,将结构性信息引入到语言模型中。
- 引入了注意力互信息(AMI),一种可以提高样本效率的自监督对比学习方法。
方法浅析
该方法旨在在对比学习的方案中利用到句子结构信息。与传统的仅在嵌入空间中的语义相似度水平上进行操作的对比学习相比,该方法在模型中注入了结构信息。这是通过在训练过程中正则化模型的注意力空间来实现的。
符号声明
给定字符串语料库,其对应的数据集表示为,表示含有个token的序列。在对句子映射时,本文采用的是bi-encoder--,其输入是输入句子的不同类型的增强表示。这里使用作为增强视图表示的索引值。因此,对batch_size为进行编码,会得到嵌入矩阵,是嵌入表示的维度。使用Transformer的话,对应于产生的还有其相关联的注意力矩阵。因此,提出的模型联合优化如下损失,以达到对语义和结构的一致性学习:
显然,损失函数中的第一项是语义的对齐,在嵌入表示空间中使用传统的InfoNCE来实现;第二项是在注意力空间中对句法的对齐,不同的是,句法的对齐仅关注正例。Embedding-level Momentum-Contrastive Learning (InfoNCE)
InfoNCE loss试图在嵌入空间中将正例对拉在一起,同时将负例对分开。具体来说,嵌入的InfoNCE推动每个样本以及相应的增强嵌入表示之间的相似性。对应的损失函数如下:
其中,分别是对应于的两种不同的嵌入表示。,为余弦相似度。显然,负例的构成包含两种形式:(1)给定batch中除了当前样本之外的样本;(2)存储在中的前序batch中的嵌入表示(这是动量编码器中扩充负样本的常用操作)。Attention-level Mutual Information (AMI)
首先关于Transformer中的注意力机制的详情这里不再进行赘述。只需要知道在每个注意力头中包含三个矩阵,,,其中,进行运算会得到注意力权重矩阵,其中表示缩放点积。最终注意力头的输出为。当然在实际的应用中为了得到不同的嵌入子空间,一般会将该注意力操作重复次,被称为多头注意力机制。在训练编码器的过程中,自注意力张量 的值会受到随机确定性过程的影响,这种随机性由dropout操作产生。因此,基于结构信息的对齐,则是要最大化之间的互信息。本文通过四个步骤来正则化注意力空间。
- Attention Tensor Slicing
Attention tensor slicing 在输入上实例化一个Transformer栈会产生一个注意张量,其中包含Transformer的层数和自注意力的头数。显然,这里将打平成1维的张量是为了在三维空间中更好的展示tensor的shape。因此,对于的某个自注意力头来说,其输出是。考虑到多层和多头,其对应的张量就变成了。
slicing函数的主要作用是将每个输入样本的注意力张量切分为个元素: 即: 其中对于,每个元素。如果token个数小于,则通过填充对不同长度的序列进行补齐,以适应批处理。(这种切分方式的好处是还是保留了句子中个token之间的相关关系。只是不知道的设定是否会对性能产生大的影响。) - Attention Sampling
虽然[PAD]的填充会使得在GPU上可以进行有效的批处理编码,但在查看相关关系时,需要放弃token对[PAD]标记产生的注意力得分。为了适应不同长度的标记化序列,对每个网格单元内的注意得分执行采样操作。采样使用的是多项式分布 对于有值的,(),构造的注意力得分池,每个得分被等概率采样,其余的则概率为。对于每个对应的表示中,采样个注意力得分构成集合 注:这里忽略了样本下标。具体来说,由以下多项式分布产生 因此,对于同一个切分元素而言,会使用相同的采样索引。
不同增强视图的采样后表示如下: - Attention Mutual Information Estimation
文章提出使用互信息来衡量不同视图下注意力模式的相似性。具体来说,采用对数正态分布对注意力得分分布进行建模。(torch.Tensor.log_normal_())
两个正态分布的元组向量的互信息可以写为其相关性函数: ρ对应于由和计算出的相关系数。
因此,对于给定样本的第的切分元素而言,其对应的互信息可以写成 这里的函数,用于实现从 Log-Normal到Normal随机变量的转换。对应的实现细节如下: 实现的伪代码中有三点需要注意:
(1)在切分注意力tensor的时候,给出的伪代码缺少了对的切分步骤;
(2)在采样的时候就已经是将的打平成1维的进行操作了;
(3)采样过后的值,应该就是将的变成了维的向量,然后进行后续的运算,最终返回互信息值的大小。 - Mutual Information Aggregation
为了计算注意力正则化的损失分量,需要聚合整个张量的分布相似性。聚合是通过对批次中每个切片和每个样本的个体相似性进行平均得到的。给定权重缩放因子,权重对齐损失写成如下形式:
相关实验设置
(1) InfoNCE中另一种负例的设置:The momentum encoder is associated with a sample queue of size |Q| = 384.
(2)切分中,关于切分个数以及采样个数的设置:From each of the (4 × (H/2)) chunks of pooled attentions, we random sample 150 joint-attention pairs for each embedding of the bi-encoder.
实验设置中关于切分的部分,看上去是将层数每4个分为一组,而在注意力头数上则是被分为了两组。
想法真的是很出彩,一般来说,大家都考虑引入额外的网络结构通过语义的嵌入表示来获取结构信息以达到语义和结构的一致性。本文作者利用自注意力计算中的中间步骤作为结构信息从而在不增加额外网络结构的情况下使得PLMs引入了结构信息。
但也会存在一定的疑问,为什么一定要进行切分?直接对每个头每个层的注意力权重进行对比不是更加方便吗?除了减少计算量,是否还有别的说法?如果是想让不同的Transformer和不同的Multi-head之间产生交互的话,采用滑动窗口的方式是不是更好?
网友评论