CVPR 2019 | Variational Information Distillation for Knowledge Transfer
https://github.com/qiu931110/RepDistiller
1.互信息
在这篇论文中,作者提出了一种新的知识蒸馏形式,该方法将知识蒸馏的最优性能定义为最大化教师和学生网络之间的互信息。那么为什么通过最大化互信息可以使得蒸馏学习变得有效呢?首先作者对互信息做了如下定义:
如上述公式所述,互信息为[教师模型的熵值] - [已知学生模型的条件下的教师模型熵值]。而我们又有如下常识:当学生模型已知,能够使得教师模型的熵很小,这说明学生模型以及获得了能够恢复教师模型所需要的“压缩”知识,间接说明了此时学生模型已经学习的很好了。而这种情况下也就是说明上述公式中的H(t|s)很小,从而使得互信息I(t;s)会很大。作者从这个角度解释了为什么可以通过最大化互信息的方式来进行蒸馏学习。
2.蒸馏过程详解
如下图所示,由于p(t|s)难以计算,作者根据文献The IM algorithm: a variational approach to information maximization. 2004.提出的IM算法,利用一个可变高斯q(t|s)来模拟p(t|s),下述公式中的大于等于操作用到了KL散度的非负性。由于蒸馏过程中H(t)和需要学习的学生模型参数无关,因此最大化互信息就转换为最大化可变高斯分布的问题。
如下公式所示,作者利用一个均值,方差可学习的高斯分布来模拟上述的q(t|s)。
如下代码所示,作者通过一个卷积小网络来模拟可变均值,并加上relu操作增强可变均值的非线性能力。
self.regressor = nn.Sequential(
conv1x1(num_input_channels, num_mid_channel),
nn.ReLU(),
conv1x1(num_mid_channel, num_mid_channel),
nn.ReLU(),
conv1x1(num_mid_channel, num_target_channels),
)
pred_mean = self.regressor(input)
并利用如下公式构建可学习的方差,其中阿尔法c是可学习参数。
self.log_scale = torch.nn.Parameter(
np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
)
pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
最终整个蒸馏过程如下图所示,学生网络除了学习自身任务的交叉熵损失外,同时与教师网络保持高互信息(MI),通过学习并估计教师网络中的分布,激发知识的传递,使相互信息最大化。
3.结果展示
在这篇文章中,作者提出了通过最大化两个神经网络之间相互信息的变分下界来实现有效知识转移的VID框架。算法是基于高斯观测模型实现的,如下结果表明,在蒸馏学习方面,算法性能优于其他基准。实话说这个算法的数学性太强了!虽然读了两遍,也把代码复现到业务中了,但对内部的细节还是没有摸得太透,后续需要把IM算法精度一遍,才有可能真正理解变分分布的概念。
网友评论