美文网首页
模型训练当中 checkpoint 作用是什么

模型训练当中 checkpoint 作用是什么

作者: luckriver | 来源:发表于2024-07-26 10:56 被阅读0次

最近在微调大语言模型的过程中发现训练时会在模型生成的目录出现很多checkpoint开头的文件夹,这些文件夹下面基本都是一套完整可用的模型文件,还比较占用空间。这里详细总结一下checkpoint 相关的使用。


训练中产生的检查点

checkpoint文件的来源

检查点(checkpoint)的概念最早出现在高性能计算领域,长时间运行的任务容易因为一些软硬件问题而失败。为了避免从头开始重新运行任务,才有了检查点的概念。在计算任务的某个时刻保存当前状态(称为检查点),如果任务中断,可以从最近的检查点恢复而不是重新开始。

深度学习领域为了应对训练过程中可能出现的中断,也采用了检查点技术。以 huggingface 的 transformer 库为例,假如采用如下训练代码,将会产生如上图所示的一系列检查点文件夹。

epochs = 10
lr = 2e-5
train_bs = 8
eval_bs = train_bs * 2

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=epochs,
    learning_rate=lr,
    per_device_train_batch_size=train_bs,
    per_device_eval_batch_size=eval_bs,
    evaluation_strategy="epoch",
    logging_steps=logging_steps
)

这里还会有一个疑问,检查点文件夹的命名规则是什么,结尾的数字可以看出都是 500 的倍数。

根据 huggingface 的文档,检查点的产生跟 TrainingArguments 的以下几个参数有关

  • save_strategy 决定了检查点保存的逻辑,有以下 3 个选项,默认为 steps
    • no 训练中不保存检查点
    • epoch 对每一个训练周期保存
    • steps 通过 save_steps 定义如何按训练步数保存
  • save_steps 两个检查点之间经历的训练步数,默认为 500 步。

按照上面训练代码的逻辑,由于这两个参数都没有制定,因此默认采用训练步数的方式保存检查点,并且每个 500 步就会保存一次。

最后还有一个问题,就是训练步数的计算,每处理一个 batch 数据并进行一次参数更新就算作一个 step,按照这个定义计算的话,总步数 = (样本数 / 批大小) * epochs。

我的样本数为 1624,批大小为 8,周期数为 10,带入公式计算总步数 = (1624 // 8) * 10 = 2030,这样也就可以解释为什么最后一个检查点的命名为 checkpoint-2000

checkpoint文件相关的使用方法

断点续训

检查点设计的初衷就是为了任务中断之后能够快速恢复,按照前面设定的逻辑,使用 transformer 库恢复训练的方法如下

# Trainer 的定义
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

# 从最近的检查点恢复训练
trainer.train(resume_from_checkpoint=True)

加载最好的模型

考虑到训练过程中发生的过拟合,常常需要选择在验证集上性能最好的模型,可通过如下设置load_best_model_at_end 达到自动选择的目的

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',            # 保存路径
    num_train_epochs=5,                # 训练周期数
    per_device_train_batch_size=32,    # 每个设备的训练batch大小
    evaluation_strategy="steps",       # 评估策略
    save_total_limit=3,                # 保留最近的3个检查点
    load_best_model_at_end=True,       # 在训练结束时加载验证集上最好的模型
    metric_for_best_model="accuracy",  # 用于选择最佳模型的指标
    greater_is_better=True             # 指标越高越好
)

其他

分布式检查点

对于分布式训练场景下的管理,参考微软推出的 DeepSpeed

相关文章

  • 2018-04-04

    TensorFlow 到底有几种模型格式? CheckPoint(*.ckpt)在训练 TensorFlow 模型...

  • tensorflow-3

    checkpoint 可以上手撸代码,明白建立网络、训练、评估测试的实现,常见模型:线性回归模型、softmax应...

  • tensorflow教程6:Supervisor长期训练帮手

    使用TensorFlow训练一个模型,可以多次运行训练操作,并在完成后保存训练参数的检查点(checkpoint)...

  • example

    预处理 构建模型 保存模型checkpoint 可视化 构建并训练 保存为h5 预测 验证 加载 重新验证

  • 2018-07-19

    sparkStreaming之checkPoint的作用解析 checkPoint的几大作用: 第一:如遇突发情况...

  • 吴恩达深度学习4.8 风格迁移的损失函数

    损失函数在深度学习当中的作用是评价模型的输出效果,一般来说,输出效果越好则损失函数的值越小。 在对模型进行训练时,...

  • 推荐系统学习

    训练集(Training set) 作用是用来拟合模型,通过设置分类器的参数,训练分类模型。后续结合验证集作用时,...

  • Checkpoint的作用

  • 2_checkpoints

    tensorflow提供两种模型格式 checkpoint:依赖于创建模型的代码 SavedModel:与模型代码...

  • 2018-12-23-tensorflow要点

    保存检查点(checkpoint) 为了得到可以用来后续恢复模型以进一步训练或评估的检查点文件(checkpoin...

网友评论

      本文标题:模型训练当中 checkpoint 作用是什么

      本文链接:https://www.haomeiwen.com/subject/kttvhjtx.html