美文网首页
PyTorch Lightning 中的批量测试及其存在的问题

PyTorch Lightning 中的批量测试及其存在的问题

作者: 莫底凯 | 来源:发表于2022-01-05 13:39 被阅读0次

    2022-1-5, Wed., 13:37 于鸢尾花基地
    可以采用如下方式对之前保存的预训练模型进行批量测试:

    for ckpt in ckpt_list:
        model = ptl_module.load_from_checkpoint(ckpt, args=args)
        trainer.test(model, dataloaders=test_dataloader)
    

    然而,在上述循环中,通过trainer.test每执行一次测试,都只是执行了一个epoch的测试(也就是执行多次ptl_module.test_step和一次ptl_module.test_epoch_end),而不可能把ckpt_list中的多个预训练模型(checkpoint)当做多个epoch,多次执行ptl_module.test_epoch_end

    我们期望,对多个checkpoint的测试能像对多个epoch的训练一样简洁:

    trainer.test(ptl_module, dataloaders=test_dataloader)
    

    怎么做到?在训练过程中,要训练多少个epoch是由参数max_epochs来决定的;而在测试过程中,怎么办?PTL并非完整地保存了所有epoch的预训练模型。

    由于在测试过程中对各checkpoint是独立测试的,如果要统计多个checkpoint的最优性能(如最大PSNR/SSIM),怎么办?这里的一个关键问题是如何保存每次测试得到的评估结果,好像PTL并未对此提供接口。

    解决方案
    PTL提供了“回调类(Callback)”(在 pytorch_lightning.callbacks 中),可以自定义一个回调类,并重载on_test_epoch_end方法,来监听ptl_module.test_epoch_end
    如何使用?只需要在定义trainer时,把该自定义的回调函数加入其参数callbacks即可:ptl.Trainer(callbacks=[MetricTracker()])。这里,MetricTracker为自定义的回调类,具体如下:

    class MetricTracker(Callback):
    
        def __init__(self):
            self.optim_metrics = None
    
        def on_test_epoch_end(self, trainer, pl_module):
            if self.optim_metrics is None:
                self.optim_metrics = pl_module.metrics_dict
                return
    
            tensorboard = pl_module.logger.experiment
            metrics_key_list, metrics_val_list = [], []
            for k in pl_module.metrics_dict:
                # comp_fun 是自己定义的比较函数
                self.optim_metrics[k] = comp_fun(self.optim_metrics[k], pl_module.metrics_dict[k])
    

    评论: 由于MetricTracker具有与Trainer相同的生命周期,因此,在整个测试过程中,MetricTracker能够维护一个最优的评估结果optim_metrics

    相关文章

      网友评论

          本文标题:PyTorch Lightning 中的批量测试及其存在的问题

          本文链接:https://www.haomeiwen.com/subject/zvihcrtx.html