美文网首页
模型训练当中 checkpoint 作用是什么

模型训练当中 checkpoint 作用是什么

作者: luckriver | 来源:发表于2024-07-26 10:56 被阅读0次

    最近在微调大语言模型的过程中发现训练时会在模型生成的目录出现很多checkpoint开头的文件夹,这些文件夹下面基本都是一套完整可用的模型文件,还比较占用空间。这里详细总结一下checkpoint 相关的使用。


    训练中产生的检查点

    checkpoint文件的来源

    检查点(checkpoint)的概念最早出现在高性能计算领域,长时间运行的任务容易因为一些软硬件问题而失败。为了避免从头开始重新运行任务,才有了检查点的概念。在计算任务的某个时刻保存当前状态(称为检查点),如果任务中断,可以从最近的检查点恢复而不是重新开始。

    深度学习领域为了应对训练过程中可能出现的中断,也采用了检查点技术。以 huggingface 的 transformer 库为例,假如采用如下训练代码,将会产生如上图所示的一系列检查点文件夹。

    epochs = 10
    lr = 2e-5
    train_bs = 8
    eval_bs = train_bs * 2
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=train_bs,
        per_device_eval_batch_size=eval_bs,
        evaluation_strategy="epoch",
        logging_steps=logging_steps
    )
    

    这里还会有一个疑问,检查点文件夹的命名规则是什么,结尾的数字可以看出都是 500 的倍数。

    根据 huggingface 的文档,检查点的产生跟 TrainingArguments 的以下几个参数有关

    • save_strategy 决定了检查点保存的逻辑,有以下 3 个选项,默认为 steps
      • no 训练中不保存检查点
      • epoch 对每一个训练周期保存
      • steps 通过 save_steps 定义如何按训练步数保存
    • save_steps 两个检查点之间经历的训练步数,默认为 500 步。

    按照上面训练代码的逻辑,由于这两个参数都没有制定,因此默认采用训练步数的方式保存检查点,并且每个 500 步就会保存一次。

    最后还有一个问题,就是训练步数的计算,每处理一个 batch 数据并进行一次参数更新就算作一个 step,按照这个定义计算的话,总步数 = (样本数 / 批大小) * epochs。

    我的样本数为 1624,批大小为 8,周期数为 10,带入公式计算总步数 = (1624 // 8) * 10 = 2030,这样也就可以解释为什么最后一个检查点的命名为 checkpoint-2000

    checkpoint文件相关的使用方法

    断点续训

    检查点设计的初衷就是为了任务中断之后能够快速恢复,按照前面设定的逻辑,使用 transformer 库恢复训练的方法如下

    # Trainer 的定义
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset
    )
    
    # 从最近的检查点恢复训练
    trainer.train(resume_from_checkpoint=True)
    
    

    加载最好的模型

    考虑到训练过程中发生的过拟合,常常需要选择在验证集上性能最好的模型,可通过如下设置load_best_model_at_end 达到自动选择的目的

    from transformers import Trainer, TrainingArguments
    
    training_args = TrainingArguments(
        output_dir='./results',            # 保存路径
        num_train_epochs=5,                # 训练周期数
        per_device_train_batch_size=32,    # 每个设备的训练batch大小
        evaluation_strategy="steps",       # 评估策略
        save_total_limit=3,                # 保留最近的3个检查点
        load_best_model_at_end=True,       # 在训练结束时加载验证集上最好的模型
        metric_for_best_model="accuracy",  # 用于选择最佳模型的指标
        greater_is_better=True             # 指标越高越好
    )
    
    

    其他

    分布式检查点

    对于分布式训练场景下的管理,参考微软推出的 DeepSpeed

    相关文章

      网友评论

          本文标题:模型训练当中 checkpoint 作用是什么

          本文链接:https://www.haomeiwen.com/subject/kttvhjtx.html