1 前言
今天分享一篇2019年NIPS会议上一篇paper,方向为multi-label classification。论文题目为:AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification。论文下载链接为:https://arxiv.org/pdf/1811.01727.pdf,项目也开源出了代码:https://github.com/yourh/AttentionXML。
概要来说,本篇paper是提出一种基于Attention机制的label 树模型,来解决大规模多标签文本分类(Extreme multi-label text classification (XMTC))问题。研究出发点为:(1)先前的方法没有充分学习输入的文本与每个label之间的关系信息;(2)对大规模的label并没有进行一个可伸缩性的学习。针对两个问题,文中对应提出对应创新点:(1)引入multi-label attention mechanism for XMTC (AttentionXML ),为每个label捕捉最相关的特征信息;(2)引入probabilistic label tree (PLT)结构体系,处理百万级别的label集合。
2 Model
本篇论文的核心就是两个概念AttentionXML与PLT,就重点讲述这两个概念,先说下probabilistic label tree (PLT)。
2.1 Probabilistic Label Tree(PLT)
该概念是最先由K. Jasinska在2016年提出(Extreme f-measure maximization using sparse probability estimates.),解决extreme-scale datasets。该概念的想法是基于:在extreme labels集合中,是存在树状层级结构的,树结构可能是真实存在,也可能是潜在的(本文解决的方向)。所以像之前XML-CNN之类的方法把所有labels看成一个平行的结构来看待,这样导致所有的label都基于一个共同的表征向量来学习预测,没法差异性的学到与每个label最相关的信息。
Probabilistic Label Tree(PLT)的提出,就是利用概率的思想在label集合中构建一个Label Tree,基于树的结构去训练模型,预测label。构建的基本方法就是:通过每类的label文本,学到到该label表征向量,然后用递归聚类(KMeans)构建Label Tree,即叶子节点是一个真的标签,非叶子节点是一个虚拟标签。Parabel在2018年提出一种递归分裂聚类方式构建一颗二叉平衡树(主要意思是:每个节点下的子节点是有数量限制的,不能超过一个范围),在XMTC任务取得当时最好的效果。
本文就是Parabel的基础上提出了改进思路:认为Parabel方法构建的树深度H(不包括根节点和叶子节点)太深,而树的深度越深,label聚类错误可能性增大,训练预测效率也降低;此外,许多尾部标签与其他不同的标签组合在一起,并分组到一个集群中,损害了尾部标签的识别效果。所以,针对上述问题,本文提出一种方法,构建一个浅的(H很小)并且宽的PLT。
上图为一个PLT示意图,方形是树的叶子,代表所有的标签;圆形代表树的节点,是构建的伪标签;即L为label集合,H=2为去掉root与leaf树的高度,K=M=4代表KMeans聚类的K值和每个节点下面的最大容量。在这样的结构下,一个样本在每个节点(z_n)的概率有如下计算模式:
其中Pa(n)是节点n的父节点,Path(n)是从根节点到node n上路径的节点集合。
接着介绍本文是如何生成PLT:
具体的是,作者通过将每个标签下文本的BOW特征求和获得该标签的特征向量,然后通过一个K-Means循环将这些标签切分成两个cluster,直到每个节点下面的标签数小于M,这些cluster对应着树的内部节点。这里我理解的是,先完全按照二叉树的方式解析成一个树T_0,只是到最后一层的时候,按不超过M进行合并展开;接着按一种方式进行剪裁,将T_0树递归方式压缩到一个浅且宽的树T_h,如下图。这部分文中说的不是特别详细,如果想弄清楚,需去了解Parabel的paper。
上图显示的是构建一颗PLT树的过程,K=M=8,H=3,L=8000(M是max_leaf是最大叶子节点数,如果某个叶子里面的标签数超过8,就会切分该节点,H表示树高-2,L表示一共的标签数)。T0表示level=0的树,里面的数字表示树每个高度的节点数。Th中的红色数字对应的节点会被移掉,以为了获得Th+1 树。可以理解为:最后一层节点数10248>8000,已涵盖所有label了,而1288=1024>512>256,所以512,256这两层可以删掉;同样的方式,一直把树的深度裁剪到预定的H=3的高度,且每个节点的容量都不超过M=8。其中注意,root节点是不受M值限制的。**
下两图为PLT生成的伪代码和文本作者在三个数据集上生成的PLT情况:
2.2 AttentionXML
在构建好了PLT后,文中作者采用的是层级方式来训练模型的。具体包括:
(1)从上至下地给每个level单独的训练一个模型,每个模型都是一个多标签学习;
(2)level-d的树AttentionXML模型,是通过每个样本的候选标签g(x)训练的。我们对第(d-1)层AttentionXML模型预测的标签的从正到负,得分由高到底进行排序。我们选择第(d-1)层的top C标签作为下一层训练的候选标签g(x)。这就像是一种额外的负采样,相比于只使用节点的正样本,我们可以得到一个更精确的对数似然近似值。
(3)在预测阶段,对于第i个样本,第j个标签的预测得分y_i,j通过概率链式法则很容易获得。为了预测效率,我们使用beam search算法,对于第d层的,我们只预测d-1层top C的标签。
每层的AttentionXML模型结构如下图,主要包含5层:Word Representation Layer,Bidirectional LSTM Layer,Multi-label Attention Layer,Fully Connected Layer,Output Layer。
在Multi-label Attention Layer上,计算方式如下,就是常规的attention计算方式,让不同的label学习到与文本向量h_i不同的权重信息。
在Fully Connected Layer上,文中是采用各个层级模型共享的方式,主要目的减少模型复杂度。
预测阶段的伪代码3 总结
这次主要分享的目的就是在处理大规模多标签文本分类任务时,如何使用层级分类的思路解决该任务,提高识别效果。本文的实验结果就不分析了。文中提出的构建浅且宽的PLT树思路,可以借鉴,类似是一种折衷的方案,既不把label都视为平行结构,也不能把label构建成特别深的树结构,影响学习效率。但其中使用KMeans去构建PLT树,我内心是有点怀疑的,这应该会产生分类误差,没有一个衡量构建好坏或更合理的指标,不过目前我也没想到好的方法,留给大家去思考了......
更多文章可关注笔者公众号:自然语言处理算法与实践
网友评论