美文网首页
Pytorch Lightning系列 如何使用ModelChe

Pytorch Lightning系列 如何使用ModelChe

作者: 四碗饭儿 | 来源:发表于2021-08-30 20:23 被阅读0次

    在训练机器学习模型时,经常需要缓存模型。ModelCheckpoint是Pytorch Lightning中的一个Callback,它就是用于模型缓存的。它会监视某个指标,每次指标达到最好的时候,它就缓存当前模型。Pytorch Lightning文档 介绍了ModelCheckpoint的详细信息。

    我们来看几个有趣的使用示例。

    示例1 注意,我们把epoch和val_loss信息也加入了模型名称。

    >>> checkpoint_callback = ModelCheckpoint(
    ...     monitor='val_loss', #我们想要监视的指标 
    ...     dirpath='my/path/',  #模型缓存目录
    ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}' # 模型名称
    ... )
    

    示例2 这个使用例子非常像示例1,唯一的差别在于指标的名称是由我们自己指定的,而不是由Pytorch Lightning自动生成的 (auto_insert_metric_name=False)。通过这样的方式,我们可以使用类似val/mrr的指标名。从而统一tensorboard和pytorch lightning对指标的不同描述方式。

    >>> checkpoint_callback = ModelCheckpoint(
    ...     monitor='val/loss',
    ...     dirpath='my/path/',
    ...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}', # 注意到val/loss变成了val_loss
    ...     auto_insert_metric_name=False
    ... )
    

    Pytorch Lightning把ModelCheckpoint当作最后一个CallBack,也就是它总是在最后执行。这一点在我看来很别扭。如果你在训练过程中想获得best_model_score或者best_model_path,它对应的是上一次模型缓存的结果,而并不是最新的模型缓存结果

    self.trainer.checkpoint_callback.best_model_score
    

    相关文章

      网友评论

          本文标题:Pytorch Lightning系列 如何使用ModelChe

          本文链接:https://www.haomeiwen.com/subject/euboiltx.html