文章名称
【WSDM-2021】【Google】Interpretable Ranking with Generalized Additive Models
核心要点
文章旨在解决ranking场景下,现有可解释模型精度不够的问题,提出将天生具有可解释性的广义加法模型(GAM)作为引入ranking场景,作为可解释排序模型。作者研究了如何将GAM应用到item和list级别的排序模型上,并利用神经网络而不是样条回归或回归树作为GAM排序模型。在此基础上利用蒸馏的方法,可以将神经排序模型蒸馏到更简单的分段函数。
上一节介绍了作者提出的ranking GAM思想(或者说解决问题的框架)以及其具体实例Neural Ranking GAM。本节继续介绍模型的训练以及蒸馏等操作。
方法细节
问题引入
上一节所述,提出在context-absent的场景下利用传统的GAM模型结合神经网络,计算单个物品的排序得分。并利用所有物品的排序得分得到最终的排序列表。在很多时候,可以利用插叙提供的丰富信息,然而直接把查询信息加入模型的线性加和部分,并不能帮助提升性能,甚至被消除掉,对最终评测指标也是没有意义的。
因此,作者提出ranking GAM,利用查询特征来学习GAM的权重,整体思路有点类似attention,只是为了避免复杂的逻辑,保持良好的可解释性,没有引入交互。作者给出了ranking GAM的实例模型,Neural Ranking GAM。那么如何训练模型呢?为了能够进一步提升,作者还提出了模型蒸馏的方案。
具体做法
首先,回顾一下背景问题的形式化定义,以及Neural Ranking GAM的框架。
framework of neural ranking GAM in context-present setting
- 数据集表示观测的数据集,整体观测数据集包括个样本。其中,分别表示查询query的向量(或者推荐的用户上下文向量),物品集合的特征矩阵(矩阵中的每一个向量表示一个物品的特征向量)以及物品和query的相关度标签(可以是0或1,也可以是表示相关性的有序列表)。
- 策略空间记作,而最优策略可以依据估计的相关性得分得到。
- 排序模型记作,最优策略可以通过在观测数据上训练模型来近似,
- 在文章的研究场景下,作者利用点估计函数,可以得到查询与物品相关性的估计值。如前所述,利用点估计的值,对物品进行排序,可以得到排序列表。
Loss
如前所述,ranking GAM不同于回归场景,其优化目标一般是ranking loss,文章利用模拟NDCG的损失函数来训练模型[3, 46],当然也可以采用MSE等模型,不过性能会打折扣。
Sub-module Distillation
为了能能够提升模型的运算速度,适应线上响应速度要求,作者提出利用模型蒸馏[20]的方式,对子模型(也就是每个特征的模型)进行蒸馏。具体的,作者采用piece-wise regression (也被称为segmented regression) [37],对数值型特征的模型进行蒸馏。
一般 piece-wise linear function,PWL由个结点决定,结点集合。其中每两个结点之间的因变量取值由两端的结点决定,具体公式如下图所示。
PWL function based on the knots的个数由实际场景决定,一般在3-5个左右。蒸馏过程中,选择最优的个结点,使得阶梯线性函数PWL得到的输出与该体征的输出的MSE最小,具体公式如下图所示。
distillation objectiveFitting Distillation
虽然通过分段回归代码库[38]来求解上述问题,但是,通常很慢并且无法扩展到大型数据集。 因此作者采用了贪心算法进行蒸馏训练。
如上所述,PWL的训练目标是得到结点集合。作者首先利用样本 的0%、1%、...、99%、100% 的百分位边界,生成一组节点候选,其中最多有101个元素。
基于这个候选集,训练目标变为从候选集中挑选个最优节点,, 的组合优化问题,。暴力解法是穷举每种可能,而作者优先利用贪心法生成初始集合(注意这里是子集合的初始集合),再不断地依据上述MSE优化,直到结果收敛。具体步骤如下,
- 构造初始集合,从中依次挑选个节点,来减小MSE。
- 遍历初始集合替换。从初始集合中寻找可以被中未用到的元素替换的结点。
上述贪心算法的具体流程参见伪代码部分。作者表示尽管是贪心算法,效果也是可以接受的。
代码实现
文章的伪代码如下图所示。
pseudo code to finding knots for PWL心得体会
蒸馏
模型蒸馏已经成了一种流行度后处理工程,这里作者利用了简单的模型来进行蒸馏,一方面进一步加快了运算速度,另一方面也提升了模型的可解释性。当然,可以采用其他模型来作为student模型,只要适用于目标场景,可以解释就好。
文章引用
[6] ChristopherJ.C.Burges.2010.FromRankNettoLambdaRanktoLambdaMART:
An Overview. Technical Report Technical Report MSR-TR-2010-82. Microsoft
Research.
[19] Trevor Hastie and Robert Tibshirani. 1986. Generalized Additive Models. Statist.
Sci. 1, 3 (1986), 297–318.
[20] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. 2015. Distilling the Knowledge in
a Neural Network. arXiv:stat.ML/1503.02531
[31] Yin Lou, Rich Caruana, and Johannes Gehrke. 2012. Intelligible models for
classification and regression. In KDD.
[37] Vito Muggeo. 2003. Estimating Regression Models with Unknown Break-Points. Statisticsinmedicine22(102003),3055–71. https://doi.org/10.1002/sim.1545
[38] Vito Muggeo. 2008. Segmented: An R Package to Fit Regression Models With Broken-Line Relationships. R News 8 (01 2008), 20–25.
[39] VinodNairandGeoffreyEHinton.2010.Rectifiedlinearunitsimproverestricted boltzmann machines. In ICML.
[51] Cynthia Rudin. 2019. Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead. Nature Machine Intelligence 1, 5 (2019), 206.
[53] Sofia Serrano and Noah A Smith. 2019. Is Attention Interpretable?. In ACL.
网友评论