这个pass是hlo层对多个all reduce instruction判断是否需要进行合并的优化pass.也就是tensor fusion了。
首先有一个结构体:
using InstructionGroups =
std::vector<std::vector<std::vector<HloInstruction*>>>;
可以看到是三个vector的嵌套,乍一看不知道是干啥的,所以从创造他的函数CreateComputationGroups
入手分析一下:
这个函数首先遍历了一下computation的所有all reduce instruction.然后创建了一个 opcode_groups.
std::map<GroupKey, std::vector<HloInstruction*>> opcode_groups;
这个是对于不同all reduce的类型(sum, mean 等)分组,比较容易理解。
接下来又基于opcode_groups创建了一个all_reduce_sets:
std::map<int64, std::vector<std::pair<int64, HloInstruction*>>>
all_reduce_sets;
int64 group_id = 0;
for (auto& domain_groups : opcode_groups) {
for (HloInstruction* hlo : domain_groups.second) {
all_reduce_sets[channel_id(hlo)].emplace_back(group_id, hlo);
}
++group_id;
}
对每一个op_code group按照遍历升序给了一个group_id。然后又按照all reduce instruction的 all_reduce_id分组。这次分组元素里不仅仅是instruction的指针了,而是group_id,instruction指针的pair.
再紧接着创建了一个all_reduce的group map:
std::map<std::vector<int64>, std::vector<std::vector<HloInstruction*>>>
all_reduce_group_map;
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() != HloOpcode::kAllReduce) {
continue;
}
if (instruction->to_apply()->instruction_count() != 3 ||
instruction->to_apply()->num_parameters() != 2) {
VLOG(1) << "Skipping due to non-trivial reduction function.";
continue;
}
int64 arid = channel_id(instruction);
if (all_reduce_sets.count(arid) == 0) {
// Already processed.
continue;
}
std::vector<int64> group_ids;
std::vector<HloInstruction*> instructions;
for (const auto& hlo : all_reduce_sets[arid]) {
group_ids.push_back(hlo.first);
instructions.push_back(hlo.second);
}
all_reduce_group_map[group_ids].push_back(std::move(instructions));
all_reduce_sets.erase(arid);
}
CHECK(all_reduce_sets.empty());
这个map的key是group id的序列,value是instruction的指针的二维数组。最后那一维数组中的所有instruction都是属于同一个all_reduce id的。
最后整个函数返回了 InstructionGroups groups;
InstructionGroups groups;
for (const auto& all_reduce_group : all_reduce_group_map) {
groups.push_back(all_reduce_group.second);
}
return std::move(groups);
}
可以看到InstructionGroups这个结构体的实际含义是一个数组,数组的每个元素代表的是具有相同group_id序列的instruction组成的二维数组。二维数组的每一行的所有instruction都有相同的all_reduce id.
每一列的所有instruction都有相同的group_id.
网友评论