美文网首页
bert 模型压缩原理

bert 模型压缩原理

作者: 小小兰哈哈 | 来源:发表于2021-10-08 22:12 被阅读0次

    1. 压缩目的:

    在基本不影响模型效果的基础上,对bert模型进行同构压缩,将layer 与embedding size减少, 尽可能提升模型的性能。

    比较经典的压缩尺寸是 12 * 768 -> 6 * 384

    下面以classifier task为例子,讲一下bert模型压缩的原理和实现.

    classifier task的model的 结构:

     BERT --> MLP -->cross_entropy_loss

    2. 基本概念

    teacher model: 尺寸较大的模型, finetune model

    student model: 尺寸较小的模型,target model

    3. distillation loss的设计

    distillation可以分为两步。第一步,使用classifier task的label 训练teacher model,如果要做的精确一点,可同时训练student model的classifier 以及teacher的sequence attention 的logits和student 的sequence attention logits做交叉熵.

    loss1 -> grad -> loss2 -> grad -> loss3->grad

    第二步,将teacher model 的 parameters 做冻结,detach(), 使用MSE Loss的方式修正student model的Mlp logits的结果

    总结:第一步,主要实现teacher model的finetune和提高student的BERT layer与teacher BERT layer的sequence结果相关性

    第二步:实现student MLP logits 与teacher MLP logits 的相关性.

    实验证明可以基本实现在效果减小很少的情况下,性能有很大提升。

    第一步的具体的流程可表示为:

    1. teacher_sequence = teacher_sequence.detach() 做梯度冻结

      teacher_attention = torch.matmul(teacher_sequence , teacher_sequence.permute(0,2,1))

      input_mask = torch.unsqueeze(input_mask, 0) * torch.unsqueeze(input_mask, 1)

     将input_mask 也变成batch size * sequence * sequence的序列组合的形式.

    teacher_att = torch.log_softmax(teacher_attention) * input_mask [使用input_mask将原序列中需要编码忽略的部分置0, 必要的时候softmax前可以将相应的mask掉的部分的值调低)

    对student_sequence 采用同样的操作.

    att_loss = teacher_att * torch.log(student_att)/(torch.sum(input_mask))

    第二步的具体流程可表示为:

    teacher_logits = teacher_logits.detach()

    mse_loss = nn.MSE()(student_logits, teacher_logits)

    相关文章

      网友评论

          本文标题:bert 模型压缩原理

          本文链接:https://www.haomeiwen.com/subject/fclfoltx.html