Wang Z, Hamza W, Song L. -Nearest Neighbor Augmented Neural Networks for Text Classification[J]. arXiv preprint arXiv:1708.07863, 2017.
摘要导读
近年来,许多基于深度学习的模型被用于文本分类。然而,在训练的过程中缺乏对训练集中实例级信息的利用。在本文中,作者建议通过利用输入文本的k-nearest neighbor(kNN)信息来加强神经网络模型对文本嵌入的学习以更好的辅助分类任务。具体来说,提出的模型采用了一个神经网络,将文本编码为嵌入表示。此外,该模型还利用输入文本的k-近邻作为外部存储器,并利用它来捕捉训练集中的实例级信息。最终的分类预测是基于神经网络编码器和kNN memory的特征进行的。实验结果显示,提出的模型在所有数据集上都优于基线模型,甚至在几个数据集上击败了29层的神经网络模型;并且在训练实例稀少和训练集严重不平衡的情况下也显示出卓越的性能;该模型甚至可以很好的利用在半监督训练和转移学习等技术中。
模型浅析
提出的模型主要从训练集中提取全局和实例级信息来进行文本分类任务。为捕捉全局信息,训练了一个神经网络编码器,根据所有的训练实例及其类别信息将文本编码到一个嵌入空间。为了捕捉实例级的信息,对于每个输入的文本,从训练集中搜索其对应的k个近邻样本,然后将其作为外部存储器来增强神经网络。
上图中蓝色的data flow是传统的文本分类流程。余下的部分即为本文提出的kNN memory,即使用注意力机制来抽取实例级别的信息用于预测。可以形式化为如下:给定样本和样本的kNN以及对应的正确标签和。因此本文的任务是基于训练集估计一个条件概率,然后可以用于测试样本的标签: 是所有可能标签的集合。- Text Encoder
该模块用于对文本进行编码,主要分为两个步骤:(1)词表示step将所有的词表示为词嵌入或字符嵌入;(2)句子表示step则是使用CNN或LSTM将词嵌入序列压缩到一个固定长度的文本嵌入。
在本模型中,使用LSTM对文本编码。首先在词表示阶段,作者为每个词构建了包含两个部分的维向量:词嵌入表示和字母组成的嵌入表示。词的嵌入是用GloVe或word2vec预训练的单独的词的固定向量。字符组成的嵌入是通过将一个词中的每个字符(也表示为一个向量)输入LSTM来计算的。在句子表示步骤,应用双向LSTM(BiLSTM)来组成单词表示序列,然后将BiLSTM最后一个时间步骤的两个向量(包括前向和后向)连接起来作为最终的文本嵌入。 - kNN Memory
kNN Memory是提出模型的核心部分,目标是为每个输入文本从其kNN中获取实例级信息。该部分包括以下六个步骤。
1、寻找kNN,对应图中黑色对应的data flow, 使用的BM25排序算法为每个样本选取k个近邻样本。
2、使用Text Encoder对近邻集合中的每个样本进行编码,对应图中黄色的data flow。
3、计算近邻关注分配,其目标是计算输入文本和嵌入空间中的每个K个邻居之间的相似性(邻居关注),对应途中灰色的data flow。令,样本和其第个近邻样本的embedding表示记录为和,都是由Text Encoder产生的维的隐含特征向量。理论上而言,所有的相似性度量函数都适用。本文作者适用的是multi-perspective cosine matching来计算两个向量之间的相似性 其中是一个可训练的参数矩阵,是控制perspective 数量的超参数,返回的是一个维的向量 每个元素表示在第个perspective中和之间的相似度,由两个向量之间的余弦相似度计算得到: 其中表示两个向量对应元素相乘,是矩阵中得第行,用于控制第个perspective,并为维文本嵌入空间的不同维度分配不同的权重。当时,一个样本与每个对应的邻居样本仅包含一个相似度值,上图中的neighbor attention就是的情况。
4、计算基于关注度的标签分布,基于近邻关注分配,权重化的将标签分布进行相加,如图中的绿色data flow。形式化来说,将文本对应的标签转化为one-hot标签分布,分别计算K个近邻在第个perspective对应的权重化的标签分布 最终的multi-perspective关注度的标签分布将个perspective的标签分布进行拼接。
5、计算基于关注度的文本嵌入表示,基于近邻关注分配,权重化的将文本嵌入进行相加,如图中的橙色data flow。类似的,第个perspective中K个近邻的文本嵌入表示为: 最终来自个perspective的表示即为。
6、可以说,前5步都是在为最终的分类任务学习更好的文本表示,有来自文本本身的,有来自文本对应K个近邻的标签和文本信息。即:
序号 | 符号表示 | 描述 |
---|---|---|
1 | 当前文本对应的text embedding | |
2 | 当前文本对应的kNN所生成的基于关注度的标签分布 | |
3 | 当前文本对应的kNN所生成的基于关注度的text embedding |
最后的步骤时将这些特征向量都拼接起来,用于学习分类预测。即将输入最终的分类器得出预测结果。在测试集上,则还是在训练集的样本中构造kNN memory,用于预测。
将训练集中的信息用的很全面。简单且有效。
网友评论