美文网首页
Keras 单机多卡情况下的知识蒸馏(knowledge distillation)

Keras 单机多卡情况下的知识蒸馏(knowledge distillation)

作者: FreeTheWorld | 来源:发表于2022-02-16 10:42 被阅读0次

Keras 官方文档给出了知识蒸馏简单模型的写法,但在开启分布式策略进行单机多卡(MirroredStrategy)或多机多卡(MultiWorkerMirroredStrategy)训练时,该怎么写呢?
亲测在 TensorFlow 2.4 环境下,成功地在开启 MirroredStrategy 的情况下跑通了模型。主要代码如下,重点讲解见注释。

Distiller模型

class Distiller(keras.Model):
    """Knowledge-distillation wrapper that trains ``student`` to mimic ``teacher``.

    Written for use under ``tf.distribute.MirroredStrategy``. Keras runs
    ``train_step`` once per replica on that replica's slice of the batch. The
    combined per-example loss is therefore averaged with
    ``tf.nn.compute_average_loss``, which divides by the GLOBAL batch size.
    The strategy then sums the gradients across replicas.

    Only the student's trainable variables are updated; the teacher is
    always called with ``training=False``.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn,
                alpha=0.1, temperature=3, global_batch_size=None):
        """Configure the optimizer, metrics and the two loss functions.

        Args:
            optimizer: Keras optimizer applied to the student's trainable
                variables.
            metrics: metrics updated against the student's predictions.
            student_loss_fn: ``fn(y_true, y_pred)`` returning PER-EXAMPLE
                losses (e.g. a Keras loss built with ``reduction=NONE``).
                Reduction happens in ``train_step``.
            distillation_loss_fn: ``fn(teacher_pred, student_pred,
                temperature=...)`` returning per-example losses.
            alpha: weight of the hard-label student loss; ``1 - alpha``
                weights the distillation loss.
            temperature: softening temperature forwarded (by keyword) to
                ``distillation_loss_fn``.
            global_batch_size: total batch size summed over all replicas.
                When ``None`` (default), ``tf.nn.compute_average_loss``
                infers it as per-replica batch size times
                ``num_replicas_in_sync``.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
        self.global_batch_size = global_batch_size

    @staticmethod
    def _unpack_data(data):
        """Split a Keras batch into ``(x, y, sample_weight)``; weight may be None."""
        if len(data) == 3:
            return data
        x, y = data
        return x, y, None

    def train_step(self, data):
        """Run one distillation step on this replica's slice of the batch.

        Returns:
            dict of metric results plus the mean ``student_loss`` and
            ``distillation_loss`` for this replica's batch.
        """
        x, y, sample_weight = self._unpack_data(data)

        # The teacher is frozen, so its forward pass stays outside the tape.
        # Under MirroredStrategy, x is the per-replica batch
        # (global batch size / number of replicas).
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            # Per-example losses (loss fns are expected to be unreduced).
            student_loss = self.student_loss_fn(y, student_predictions)
            # Pass temperature by keyword so it cannot bind to another
            # positional parameter (e.g. a sample_weight argument) of the
            # distillation loss function.
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, student_predictions,
                temperature=self.temperature)

            per_example_loss = (self.alpha * student_loss
                                + (1 - self.alpha) * distillation_loss)
            # Divide by the GLOBAL batch size, not the per-replica one,
            # because gradients are summed across replicas.
            loss = tf.nn.compute_average_loss(
                per_example_loss,
                sample_weight=sample_weight,
                global_batch_size=getattr(self, "global_batch_size", None))

        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(
            y, student_predictions, sample_weight=sample_weight)

        results = {m.name: m.result() for m in self.metrics}
        # Report scalars rather than whole per-example loss vectors.
        results.update({
            "student_loss": tf.reduce_mean(student_loss),
            "distillation_loss": tf.reduce_mean(distillation_loss),
        })
        return results

    def test_step(self, data):
        """Evaluate the student alone on one batch (the teacher is not used).

        Returns:
            dict of metric results plus the mean ``student_loss`` for the batch.
        """
        x, y, sample_weight = self._unpack_data(data)

        student_out = self.student(x, training=False)
        # A multi-output student returns a list/tuple, and its first output
        # is scored. A single-output student returns the tensor itself. That
        # tensor must not be indexed, because [0] would select only the first
        # example of the batch.
        if isinstance(student_out, (list, tuple)):
            y_prediction = student_out[0]
        else:
            y_prediction = student_out

        student_loss = self.student_loss_fn(y, y_prediction)

        self.compiled_metrics.update_state(
            y, y_prediction, sample_weight=sample_weight)

        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": tf.reduce_mean(student_loss)})
        return results

train 训练过程

def train(batch, train_dataset, test_dataset):
    """Train a teacher model, then distill it into a student under MirroredStrategy.

    Args:
        batch: per-replica batch size. It is used here only to log the global
            batch size. Presumably the datasets are already batched with that
            global size — TODO confirm against the caller.
        train_dataset: training data passed to ``fit``.
        test_dataset: validation data passed to ``fit`` (4 validation steps).

    Relies on names defined elsewhere in the file/project: ``gpus``,
    ``teacher_model``, ``student_model``, ``get_metrics``, ``ARGS``,
    ``logging``, ``keras``, ``tf`` and ``Distiller``.
    """
    mirrored_strategy = tf.distribute.MirroredStrategy(devices=gpus)

    batch_size = batch * mirrored_strategy.num_replicas_in_sync
    logging.info("batch size: %d, %d gpus", batch_size, mirrored_strategy.num_replicas_in_sync)

    # Variables must be created inside the strategy scope so they are mirrored.
    # NOTE(review): assumes teacher_model() returns an already-compiled model.
    with mirrored_strategy.scope():
        teacher = teacher_model()
    teacher.fit(train_dataset,
                validation_data=test_dataset,
                validation_steps=4,
                epochs=4)

    with mirrored_strategy.scope():
        student = student_model()
        # Binary-classification loss; for multi-class, switch to a
        # softmax / categorical cross-entropy. Losses must also be defined
        # inside the strategy scope. reduction=NONE is required: the
        # Distiller averages per-example losses over the global batch
        # itself.
        binary_loss = tf.keras.losses.BinaryCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.NONE)
        eps = tf.keras.backend.epsilon()

        def _soften(probs, temperature):
            """Recover logits from sigmoid probabilities, divide by T, re-apply sigmoid."""
            # Clip so log(p / (1 - p)) stays finite when p is exactly 0 or 1.
            probs = tf.clip_by_value(probs, eps, 1.0 - eps)
            logits = tf.math.log(probs / (1.0 - probs))
            return tf.sigmoid(logits / temperature)

        def distill_loss(teacher_predictions, student_predictions, temperature=None, sample_weight=None):
            """Per-example BCE between temperature-softened teacher and student outputs.

            ``temperature`` is the third parameter because the Distiller passes
            it as the third argument. No softening is done when it is falsy or 1.
            """
            if temperature and temperature != 1:
                teacher_predictions = _soften(teacher_predictions, temperature)
                student_predictions = _soften(student_predictions, temperature)

            per_example_loss = binary_loss(teacher_predictions, student_predictions, sample_weight=sample_weight)
            return per_example_loss

        distiller = Distiller(student=student, teacher=teacher)
        distiller.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.001, amsgrad=True),
            metrics=get_metrics(),
            student_loss_fn=binary_loss,
            distillation_loss_fn=distill_loss,
            alpha=ARGS.alpha,
            temperature=ARGS.temperature)

    distiller.fit(train_dataset,
                  validation_data=test_dataset,
                  validation_steps=4,
                  epochs=6)

相关文章

网友评论

      本文标题:Keras 单机多卡情况下的知识蒸馏(knowledge distillation)

      本文链接:https://www.haomeiwen.com/subject/obrllrtx.html