美文网首页
[PyTorch]多卡运行(transformer-xl)

[PyTorch]多卡运行(transformer-xl)

作者: VanJordan | 来源:发表于2019-03-31 15:26 被阅读0次

    原理

    • GPU运行的接口是torch.nn.DataParallel(module, device_ids)其中module 参数是所要执行的模型,而 device_ids 则是指定并行的 GPU id 列表。
    • 而其并行处理机制是,首先将模型加载到主 GPU 上,然后再将模型复制到各个指定的从 GPU 中,然后将输入数据按batch 维度进行划分,具体来说就是每个 GPU 分配到的数据 batch 数量是总输入数据的 batch 除以指定 GPU 个数。每个 GPU 将针对各自的输入数据独立进行 forward 计算,最后将各个 GPUloss 进行求和,再用反向传播更新单个 GPU 上的模型参数,再将更新后的模型参数复制到剩余指定的 GPU 中,这样就完成了一次迭代计算。所以该接口还要求输入数据的 batch 数量要不小于所指定的 GPU 数量。
    • DataParallel自动地分割输入数据,同时将他们发送到每个GPU的模型中. 当模型处理完成后,DataParallel会将各个设备中的处理结果收集和合并,再返回给用户。
      示意图

    需要注意

    • GPU 默认情况下是 0GPU,也可以通过 torch.cuda.set_device(id) 来手动更改默认 GPU
    • 提供的多 GPU 并行列表中需要包含有主 GPU
    • 但是,DataParallel 有一个问题:GPU 使用不均衡。在一些设置下,主GPU 会比其他 GPU 使用率高得多。

    例子

    • 构建多GPUDataParallel
    if args.multi_gpu:
        model = model.to(device)
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model.to(device)
    
    • 在正向传播的时候使用para_model,其他的时候,比如使用模型的参数可以直接调用model.parameters()

    • 均衡的DataParallel

    class BalancedDataParallel(DataParallel):
        def __init__(self, gpu0_bsz, *args, **kwargs):
            self.gpu0_bsz = gpu0_bsz
            super().__init__(*args, **kwargs)
    
        def forward(self, *inputs, **kwargs):
            if not self.device_ids:
                return self.module(*inputs, **kwargs)
            if self.gpu0_bsz == 0:
                device_ids = self.device_ids[1:]
            else:
                device_ids = self.device_ids
            inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
            if len(self.device_ids) == 1:
                return self.module(*inputs[0], **kwargs[0])
            replicas = self.replicate(self.module, self.device_ids)
            if self.gpu0_bsz == 0:
                replicas = replicas[1:]
            outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
            return self.gather(outputs, self.output_device)
    
        def parallel_apply(self, replicas, device_ids, inputs, kwargs):
            return parallel_apply(replicas, inputs, kwargs, device_ids)
    
        def scatter(self, inputs, kwargs, device_ids):
            bsz = inputs[0].size(self.dim)
            num_dev = len(self.device_ids)
            gpu0_bsz = self.gpu0_bsz
            bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
            if gpu0_bsz < bsz_unit:
                chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
                delta = bsz - sum(chunk_sizes)
                for i in range(delta):
                    chunk_sizes[i + 1] += 1
                if gpu0_bsz == 0:
                    chunk_sizes = chunk_sizes[1:]
            else:
                return super().scatter(inputs, kwargs, device_ids)
            return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
    

    相关文章

      网友评论

          本文标题:[PyTorch]多卡运行(transformer-xl)

          本文链接:https://www.haomeiwen.com/subject/xcxubqtx.html