美文网首页
论文阅读_知识蒸馏_Meta-KD

论文阅读_知识蒸馏_Meta-KD

作者: xieyan0811 | 来源:发表于2022-09-17 22:19 被阅读0次

    英文题目:Meta-KD: A Meta Knowledge Distillation Framework for Language Model Compression across Domains
    中文题目:Meta-KD:跨领域语言模型压缩的元知识蒸馏框架
    论文地址:http://export.arxiv.org/pdf/2012.01266v1.pdf
    领域:自然语言处理, 知识蒸馏
    发表时间:2020.12
    作者:Haojie Pan,阿里团队
    出处:ACL
    被引量:1
    代码和数据:https://github.com/alibaba/EasyNLP(集成于EasyNLP)
    阅读时间:2022-09-17

    读后感

    结合元学习和蒸馏学习:元学习使得模型获取调整超参数的能力,使其可以在已有知识的基础上快速学习新任务。

    介绍

    预训练的自然语言模型虽然效果好,但占空间大,预测时间长,使模型不能应用于实时预测任务。典型的方法是使用基于老师/学生模型的知识蒸馏。而模型一般面向单一领域,忽略了不同领域知识的知识转移。本文提出元蒸馏算法,致力于基于元学习的理论,让老师模型具有更大的转移能力,尤其对few-shot和zero-shot任务效果更好。

    如图-1所示,一个学物理的学生如果跟数学老师学习了数学方程知识,可能有助于他更好地理解物理方程。相近领域的数据可能提升模型的能力,但其它领域模型也可能转移一些无关的知识,从而影响性能。另外,当前研究证明:使用多任务精调也未必能提升所有任务的性能。由此,文中提出需要让老师模型消化不同领域的知识,并可针对具体领域,将知识转移到学生模型。在图-1(c)中,如果有万能的科学老师(元学习),它既会数学也会物理,则可以更好地教导学生。

    如图-2所示,模型包含两部分:元老师和元蒸馏:

    首先利用多领域数据集训练元老师,通过引入破坏域损失来获取跨域知识,然后针对具体领域,用领域相关数据集引导元老师,以提升学生的蒸馏能力。

    文章贡献

    • 第一次提出基于元学习的预训练自然语言模型压缩算法。
    • 提出Meta-KD框架训练跨领域的老师模型,包含元老师和元蒸馏两部分
    • 实验证明模型的有效性

    方法

    概览

    定义:设有K个领域的K个数据集参与训练,D为数据集,M为大模型,S为蒸馏后的学习模型。
    模型训练分为两个场景:

    • 训练一个学习了K个领域知识的元老师模型M,模型消化了各领域知识且有针对不同领域很好的泛化能力。
    • 在元蒸馏过程中,利用领域数据集DK和元模型M,训练学生模型SK。
      如果某一个领域的实例很少,如few-shot或zero-shot问题,通过知识转移训练该领域模型。

    元老师学习

    将BERT模型作为基础模型。

    基于原型实例加权
    学习过程中对每个实例X计算原型得分t,假设处理分类问题,共m个类别,计算所有第K领域中实例属于每个类别的概率均值(请参考图-3左侧的实心多边形):

    计算原型得分如下:

    此处cos用于计算相似度,α是超参数,公式的前半部分计算了该实体与它所在的领域的关系(在嵌入空间与同类实体的一致性),后半部分计算了与其它领域的关系。这样模型就同时学习了同一领域的知识和其它领域的知识。

    域破坏
    除了交叉熵损失,还加入了域破坏损失以提升元老师转移学习的能力。对于每个实例,学习一个与h维度相同的域嵌入,记作ED(epsilon D)。

    在BERT以外,又加入了一个子网络,对网络输出进一步处理:

    针对域破坏的损失函数定义为:

    其中σ(sigma)表示域类别,它是一个指示函数,只有0/1两个取值,这里最大化元教师对域标签做出错误预测的可能性。

    我理解,这里的损失函数是让实例最终能识别它所在的域类别k。

    损失函数
    最终的损失定义为:使用得分t加权针对所有领域的交叉损失;同时,加入了域破坏损失作为辅助,以训练模型转移知识的能力。

    这里的γ1(gamma)是超参数,用于设定域破坏损失的贡献。

    元蒸馏

    使用小型的BERT作为学生模型,蒸馏网络结构如图-3所示:

    目标由五个部分组成:输入嵌入Lembd,隐藏层状态Lhidn,注意力矩阵Lattn,输出ligit和知识转移。其中Lembd,Lhidn,Lattn的蒸馏方法与TinyBERT一样。又加入了Lpred对输出层使用软交叉熵损失。
    另外,考虑到特定领域的知识转移,下面公式又加入了域相关的损失:

    以此鼓励学生模型学习更多的该领域相关知识。我理解这里的hM是指对该领域的老师模型获得的编码。

    又引入λk参数,它是领域相关的权重:

    其中y^是预测的类别标签,当预测准确,或者t比较大时,λ值也相应变大,它反应的是老师在特定任务上监督学生的能力。

    整体蒸馏损失计算方法如下:

    实验

    使用自然语言推理(MNLI)和情绪分析(Amazon Reviews)两个任务评价模型。
    表-2和3展示了主实验结果:

    得出三个结论:

    • Meta-KD模型优于之前模型,它比基线模型小7.5倍,效果仅差0.5%
    • Meta-teacher模型效果很好,这表明元老师有能力学习更多可转移的知识来帮助学生。
    • 一般情况下,Meta-KD对小数据集数据效果更明显。

    图-4也说明在few-shot情况下,实例越少,Meta-KD效果越明显:

    相关文章

      网友评论

          本文标题:论文阅读_知识蒸馏_Meta-KD

          本文链接:https://www.haomeiwen.com/subject/xgweortx.html