From bad016af5d50a9395b41dac93d8e42cbb4b6742a Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Wed, 14 Jun 2023 03:29:45 +0000 Subject: [PATCH 1/2] Delete unused codes --- cinn/hlir/pass/general_fusion_merge_pass.cc | 357 +------------------- 1 file changed, 5 insertions(+), 352 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index a47c37ce73..5e23df3041 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -41,8 +41,12 @@ using GroupList = std::vector; using OpGroupPtr = std::shared_ptr; using OpGroupList = std::vector; +<<<<<<< HEAD using ConditionFunction = std::function; +======= +class GraphGroupLightwareFusePassCtx; +>>>>>>> Delete unused codes class FuseHelper { public: virtual ~FuseHelper() = default; @@ -750,8 +754,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { public: GeneralFusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) { fusion_groups_ = graph->fusion_groups; - // init fusion relation. - InitFusionRelation(); // init input to consumers. InitInputToConsumers(); // init fusion group index. @@ -809,31 +811,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return updated; } - bool DoVerticalFusion(bool recompute) { - VLOG(3) << "DoVerticalFusion...!"; - bool updated = false; - for (int idx = 0; idx < fusion_groups_.size(); ++idx) { - auto producer = fusion_groups_[idx]; - VLOG(3) << "Fusion Producer Group -> " << producer->group_id; - // if producer is sub group. - if (producer->belong_groups.size()) { - continue; - } - // do horizontal fusion. - if (!recompute) { - updated |= HorizontalFusion(producer, producer->CollectConsumerGroups()); - } - updated |= VerticalFusion(producer, producer->CollectConsumerGroups(), recompute); - } - // fuse input consumers - updated |= FuseInputToConsumers(); - - if (updated) { - UpdateFusionGroup(); - } - return updated; - } - bool DoGeneralVerticalFusion() { VLOG(3) << "DoGeneralVerticalFusion...!"; bool updated = false; @@ -1033,76 +1010,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return true; } - bool HorizontalFusion(GroupPtr producer, const std::unordered_set& consumers) { - VLOG(3) << "HorizontalFusion...!"; - if (consumers.size() <= 1) { - return false; - } - - std::unordered_set candidates; - for (const auto& consumer : consumers) { - // relation - auto& relation = fusion_relation_map_[consumer->op_pattern_kind]; - // check horizontal relation exist - if (!relation.horizontal_relation.size()) { - continue; - } - candidates.insert(consumer); - } - - std::vector fusionable_consumers; - for (auto& candidate : candidates) { - // check dependency - if (IsDependencySimplify(producer, candidate, candidates)) { - VLOG(4) << "IsDependencySimplify, Can't fuse " << candidate->group_id << ", As it depency others!"; - continue; - } - - if (IsDependency(producer, candidate, candidates)) { - VLOG(4) << "IsDependency, Can't fuse " << candidate->group_id << ", As it depency others!"; - continue; - } - - if (!fusionable_consumers.size()) { - fusionable_consumers.push_back({candidate}); - continue; - } - - // check each fusionable groups - bool fusionable = false; - auto& relation = fusion_relation_map_[candidate->op_pattern_kind]; - for (auto& groups : fusionable_consumers) { - auto& last = groups.back(); - if (!relation.horizontal_relation.count(last->op_pattern_kind)) { - continue; - } - - if (!relation.horizontal_relation[last->op_pattern_kind](this, candidate, last)) { - continue; - } - - groups.push_back(candidate); - fusionable = true; - break; - } - - // if can't fuse to othors Groups, new Groups. - if (!fusionable) { - fusionable_consumers.push_back({candidate}); - } - } - - bool updated = false; - for (auto& groups : fusionable_consumers) { - if (groups.size() > 1) { - updated = true; - HorizontalFuse(groups); - } - } - - return updated; - } - void HorizontalFuse(const GroupList& consumers) { VLOG(3) << "HorizontalFuse Groups..."; // create fusion group @@ -1248,79 +1155,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { CHECK(fused_group->output_nodes.size()) << "No output node is found, " << fused_group->group_id; } - bool VerticalFusion(GroupPtr& producer, const std::unordered_set& consumers, bool recompute) { - VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size(); - auto& relation = fusion_relation_map_[producer->op_pattern_kind]; - // if producer can't fuse others - if (!relation.vertical_relation.size()) { - return false; - } - - std::unordered_set fuse_consumers_unsafe; - std::unordered_set fuse_consumers; - for (const auto& consumer : consumers) { - VLOG(4) << "Check consuemr " << consumer->group_id << " can fuse to producer " << producer->group_id; - // if can't fuse - if (!relation.vertical_relation.count(consumer->op_pattern_kind)) { - VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; - continue; - } - - // if condition function is false - if (!relation.vertical_relation[consumer->op_pattern_kind](this, producer, consumer)) { - VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer " << consumer->group_id; - continue; - } - - fuse_consumers_unsafe.insert(consumer); - - if (IsDependencySimplify(producer, consumer, consumers)) { - VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id << " can't be master fused group!"; - continue; - } - - if (IsDependency(producer, consumer, consumers)) { - VLOG(4) << "IsDependency, Consumer " << consumer->group_id << " can't be master fused group!"; - continue; - } - - fuse_consumers.insert(consumer); - } - - VLOG(3) << "VerticalFusion, Number of fuse Consumers : " << fuse_consumers.size(); - VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : " << fuse_consumers.size(); - - if (fuse_consumers.size() == 0) { - return false; - } - // if can_fuse_consumers == consumers - // if producer op kind == kElementwise - // if use recompute - if (fuse_consumers_unsafe.size() == producer->consumer_groups().size() && - producer->op_pattern_kind == framework::kElementWise) { - if (!recompute) { - return false; - } else { - RecomputeEleGraph(producer, fuse_consumers_unsafe); - VerticalFuse(producer, fuse_consumers_unsafe); - return true; - } - } - - if (fuse_consumers.size()) { - SelectConsumerToFuse(producer, fuse_consumers); - } - - // if fusionable consumers exist - if (fuse_consumers.size()) { - VerticalFuse(producer, fuse_consumers); - return true; - } - - return false; - } - - std::vector> RawVerticalFusePasses() const { + std::vector> RawVerticalFusePasses() const { return FusionPassMap::Instance().GetLightwareFusePassesByMode("VerticalFuse"); } @@ -1625,12 +1460,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { VerticalFuse(producer, fusionable_consumers); } - void RecomputeEleGraph(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { - if (producer->op_pattern_kind != framework::kElementWise) { - SelectConsumerToFuse(producer, fusionable_consumers); - } - } - void SelectConsumerToFuse(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { // if is const op if (is_const_group(this, producer)) { @@ -1727,86 +1556,6 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - bool IsDependency(const GroupPtr& producer_g, - const GroupPtr& consumer, - const std::unordered_set& consumers) { - std::queue candidates; - candidates.push(consumer); - - std::unordered_set visited_set; - while (!candidates.empty()) { - auto& candidate = candidates.front(); - candidates.pop(); - for (const auto& producer_and_list : candidate->producer_groups()) { - if (producer_and_list.first.get() == producer_g.get()) { - continue; - } - const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); - if (consumers.count(producer)) { - return true; - } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); - } - } - } - return false; - } - - bool IsDependencySimplify(const GroupPtr& producer_g, - const GroupPtr& consumer, - const std::unordered_set& consumers) { - std::queue candidates; - candidates.push(consumer); - // check upper. - int check_upper_depth = producer_g.get() ? producer_g->max_depth : INT_MAX; - std::unordered_set visited_set; - while (!candidates.empty()) { - auto& candidate = candidates.front(); - candidates.pop(); - for (auto& producer_and_list : candidate->producer_groups()) { - if (producer_and_list.first.get() == producer_g.get()) { - continue; - } - const auto& producer = std::dynamic_pointer_cast(producer_and_list.first); - if (producer->min_depth > check_upper_depth) { - continue; - } - if (consumers.count(producer)) { - return true; - } - if (!visited_set.count(producer)) { - visited_set.insert(producer); - candidates.push(producer); - } - } - } - return false; - } - - bool FuseInputToConsumers() { - VLOG(3) << "FuseInputToConsumers...!"; - auto updated = false; - UpdateInputToConsumers(); - GroupPtr producer(nullptr); - for (auto& input_consumers : input_to_consumers_) { - // if group set size == 1. - if (input_consumers.second.size() == 1) { - continue; - } - // do horizontal fusion. - auto st = HorizontalFusion(producer, input_consumers.second); - if (st) { - // fused consumers, update - UpdateInputToConsumers(); - } - updated |= st; - } - - return updated; - } - bool GeneralInputFuse() { VLOG(3) << "GeneralInputFuse...!"; auto updated = false; @@ -1922,105 +1671,9 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } - void InitFusionRelation() { - VLOG(3) << "InitFusionRelation...!"; - // kElementWise - { - auto& relation = fusion_relation_map_[OpPatternKind::kElementWise]; - // horizontal - relation.horizontal_relation = {{framework::kElementWise, is_same_size}, - // element-wise and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // element-wise and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // element-wise and reduce op must be horizontal relation. - {OpPatternKind::kReduction, honrizontal_elementwise_fuse_reduce}}; - // vertical - relation.vertical_relation = {{OpPatternKind::kElementWise, is_same_size}, - // element-wise and broadcast can be vertical/horizontal relation. - {OpPatternKind::kBroadcast, elementwise_fuse_broadcast}, - // element-wise and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // element-wise and reduce can be vertical/horizontal relation. - {OpPatternKind::kReduction, elementwise_fuse_reduce}}; - } - // kBroadcast - { - auto& relation = fusion_relation_map_[OpPatternKind::kBroadcast]; - // horizontal - relation.horizontal_relation = {// broadcast and element-wise op must be horizontal relation. - {framework::kElementWise, is_same_size}, - // broadcast and broadcast op must be horizontal relation. - {framework::kBroadcast, is_same_size}, - // broadcast and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // broadcast and reduce op must be horizontal relation. - {OpPatternKind::kReduction, is_same_size}}; - // vertical - relation.vertical_relation = {// broadcast and element-wise op must be vertical relation. - {OpPatternKind::kElementWise, is_same_size}, - // broadcast and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // broadcast and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // broadcast and reduce must be vertical relation. - {OpPatternKind::kReduction, broadcast_fuse_reduce}}; - } - // kInjective - { - auto& relation = fusion_relation_map_[OpPatternKind::kInjective]; - // horizontal - relation.horizontal_relation = {// injective and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, is_same_size}, - // injective and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // injective and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // injective and reduce must be horizontal relation. - {OpPatternKind::kReduction, is_same_size}}; - // vertical - relation.vertical_relation = {// injective and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, is_same_size}, - // injective and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // injective and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // injective and reduce can be horizontal/vertical relation. - {OpPatternKind::kReduction, injective_horizontal_with_reduce}}; - } - // kReduction - { - auto& relation = fusion_relation_map_[OpPatternKind::kReduction]; - // horizontal - relation.horizontal_relation = {// reduce and element-wise op must be horizontal relation. - {OpPatternKind::kElementWise, honrizontal_elementwise_fuse_reduce}, - // reduce and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, - // reduce and injective op must be horizontal relation. - {OpPatternKind::kInjective, is_same_size}, - // reduce and reduce must be horizontal relation. - {OpPatternKind::kReduction, reduce_fuse_reduce}}; - // vertical - relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation. - {OpPatternKind::kElementWise, reduce_fuse_elementwise}, - // reduce and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, - // reduce and injective op must be horizontal relation. - {OpPatternKind::kInjective, horizontal_with_injective}, - // reduce and reduce must be horizontal relation. - {OpPatternKind::kReduction, reduce_fuse_reduce}}; - } - } - GroupList fusion_groups_; std::unordered_map fusion_groups_index_; std::unordered_map> input_to_consumers_; - - struct Relation { - std::unordered_map vertical_relation; - std::unordered_map horizontal_relation; - }; - std::unordered_map fusion_relation_map_; }; void GeneralFusionMergePassInternal(Graph* graph) { From d09e2de8bfedd27b0d9baad717dfc876a06b21aa Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 16 Jun 2023 03:00:36 +0000 Subject: [PATCH 2/2] resolve conflict --- cinn/hlir/pass/general_fusion_merge_pass.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cinn/hlir/pass/general_fusion_merge_pass.cc b/cinn/hlir/pass/general_fusion_merge_pass.cc index 5e23df3041..f347e1f19b 100644 --- a/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -41,12 +41,7 @@ using GroupList = std::vector; using OpGroupPtr = std::shared_ptr; using OpGroupList = std::vector; -<<<<<<< HEAD -using ConditionFunction = std::function; - -======= class GraphGroupLightwareFusePassCtx; ->>>>>>> Delete unused codes class FuseHelper { public: virtual ~FuseHelper() = default;