美文网首页
pytorch-lightning的一些记录

pytorch-lightning的一些记录

作者: IT_小马哥 | 来源:发表于2022-10-24 20:22 被阅读0次
    • 收集每个GPU上的输出
      在分布式训练时,每个GPU都会有一部分数据,当我们需要使用全部的数据进行计算时,我们需要收集所有GPU的tensor。
      比如两个GPU,第一个GPU有16组数据,第二个GPU有16组数据, 在进行对比学习计算时,我们需要收集所有的输出来增加负样本的数量。
      我们可以使用tensors_from_all = self.all_gather(my_tensor)
      比如:
        def training_step(self, batch, batch_idx):
            outputs = self(batch)
            ...
    
            all_outputs = self.all_gather(outputs, sync_grads=True)
    
            loss = contrastive_loss_fn(all_outputs, ...)
            return loss
    

    相关文章

      网友评论

          本文标题:pytorch-lightning的一些记录

          本文链接:https://www.haomeiwen.com/subject/qgwgzrtx.html