美文网首页
XLA all reduce combiner pass 分析

XLA all reduce combiner pass 分析

作者: yxd886 | 来源:发表于2020-11-06 15:41 被阅读0次

    这个pass是hlo层对多个all reduce instruction判断是否需要进行合并的优化pass.也就是tensor fusion了。
    首先有一个结构体:

    using InstructionGroups =
        std::vector<std::vector<std::vector<HloInstruction*>>>;
    

    可以看到是三个vector的嵌套,乍一看不知道是干啥的,所以从创造他的函数CreateComputationGroups入手分析一下:
    这个函数首先遍历了一下computation的所有all reduce instruction.然后创建了一个 opcode_groups.

    std::map<GroupKey, std::vector<HloInstruction*>> opcode_groups;
    

    这个是对于不同all reduce的类型(sum, mean 等)分组,比较容易理解。
    接下来又基于opcode_groups创建了一个all_reduce_sets:

    std::map<int64, std::vector<std::pair<int64, HloInstruction*>>>
          all_reduce_sets;
      int64 group_id = 0;
      for (auto& domain_groups : opcode_groups) {
        for (HloInstruction* hlo : domain_groups.second) {
          all_reduce_sets[channel_id(hlo)].emplace_back(group_id, hlo);
        }
        ++group_id;
      }
    

    对每一个op_code group按照遍历升序给了一个group_id。然后又按照all reduce instruction的 all_reduce_id分组。这次分组元素里不仅仅是instruction的指针了,而是group_id,instruction指针的pair.

    再紧接着创建了一个all_reduce的group map:

      std::map<std::vector<int64>, std::vector<std::vector<HloInstruction*>>>
          all_reduce_group_map;
      for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
        if (instruction->opcode() != HloOpcode::kAllReduce) {
          continue;
        }
        if (instruction->to_apply()->instruction_count() != 3 ||
            instruction->to_apply()->num_parameters() != 2) {
          VLOG(1) << "Skipping due to non-trivial reduction function.";
          continue;
        }
    
        int64 arid = channel_id(instruction);
        if (all_reduce_sets.count(arid) == 0) {
          // Already processed.
          continue;
        }
    
        std::vector<int64> group_ids;
        std::vector<HloInstruction*> instructions;
        for (const auto& hlo : all_reduce_sets[arid]) {
          group_ids.push_back(hlo.first);
          instructions.push_back(hlo.second);
        }
        all_reduce_group_map[group_ids].push_back(std::move(instructions));
        all_reduce_sets.erase(arid);
      }
      CHECK(all_reduce_sets.empty());
    

    这个map的key是group id的序列,value是instruction的指针的二维数组。最后那一维数组中的所有instruction都是属于同一个all_reduce id的。

    最后整个函数返回了 InstructionGroups groups;

      InstructionGroups groups;
      for (const auto& all_reduce_group : all_reduce_group_map) {
        groups.push_back(all_reduce_group.second);
      }
      return std::move(groups);
    }
    

    可以看到InstructionGroups这个结构体的实际含义是一个数组,数组的每个元素代表的是具有相同group_id序列的instruction组成的二维数组。二维数组的每一行的所有instruction都有相同的all_reduce id.
    每一列的所有instruction都有相同的group_id.

    相关文章

      网友评论

          本文标题:XLA all reduce combiner pass 分析

          本文链接:https://www.haomeiwen.com/subject/beptbktx.html