参考文献:
https://github.com/DA-southampton/NLP_ability/blob/master/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E8%87%AA%E7%84%B6%E8%AF%AD%E8%A8%80%E5%A4%84%E7%90%86/%E6%A8%A1%E5%9E%8B%E8%92%B8%E9%A6%8F/bert2textcnn%E6%A8%A1%E5%9E%8B%E8%92%B8%E9%A6%8F.md
https://zhuanlan.zhihu.com/p/82129871
这里的蒸馏其实有点意思,什么是蒸馏,这词起的很有意思,比如如何从海水中炼得海盐。就是搞个火炉,去把海水蒸发掉,留下来的就是最精华的海盐。
那知识蒸馏就是把大模型的精华蒸馏出来,给小模型用。详细可以参考tinybert的实现。
这里简单讲讲如何将bert的精华蒸馏到textcnn等速度较快精度也较高的模型中。
简单来讲,就是同一个语句,输入bert得到一个logit,然后输入textcnn得到一个logit,两个logit之间做mse损失就可以了。其实就是想让我的textcnn学习到比较精准logit的结果,而不是简单的0-1,因为logit里面其实有很多的隐含知识,并不是最后简单的label信息。至于损失函数,可以如下:
对应再加上一些数据增强的措施(类bert操作),增强数据,防止过拟合,如:
Masking 使用[mask]标签来随机替换一个单词,例如“I love the comedy"替换为” I [mask] the comedy"。
POS-guided word replacement 将一个单词替换为相同POS标签的随机单词。例如,“What do pigs eat?"替换为"How do pigs eat?"。
n-gram sampling 随机采用n-gram,n从1到5,并丢弃其它单词。
具体的训练猜测:
同时训练两个模型,一种是bert的fineturn,一种是textcnn的学习,训练阶段的时候,可以同时获取两个模型的logit,然后计算mse损失,然后进行回传。这里的回传为了保证bert不受影响,可以不用回传到bert那侧,bert那边正常fineturn即可。
网友评论