自蒸馏整体网络结构:
network
其中,bottleneck可减轻每个浅分类器之间的影响,添加teacher隐藏层L2 loss,并且使teacher与student网络feature map输出大小一致。
三个损失函数:
- 交叉熵损失(从标签到最深分类器和浅分类器):根据数据集标签与分类器softmax输出进行计算
- KL散度:计算teacher与student 之间的softmax
- L2 loss:计算最深分类器与浅分类器feature map 之间的 L2 loss
总体损失:
C表示CNN中分类器个数
其中,最深分类器的λ和α为零,即最深分类器的监督仅来自标签。
注意
- 自蒸馏存在梯度消失的问题,因此较深的神经网络较难训练
- 自蒸馏一种提高模型性能的训练技术,而不是一种压缩模型的方法
网友评论