From 5b7dc57bc48ac3a99c2f1c20ba79099480b09be0 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Sun, 10 Mar 2024 08:35:11 +0000 Subject: [PATCH] add group_pattern_util.InferShardableAxesFromSink --- paddle/cinn/frontend/group_pattern_util.cc | 55 +++++++++++++--------- paddle/cinn/frontend/group_pattern_util.h | 6 ++- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc index ba146aa0dbd07..c5660222cf0af 100644 --- a/paddle/cinn/frontend/group_pattern_util.cc +++ b/paddle/cinn/frontend/group_pattern_util.cc @@ -239,9 +239,8 @@ std::unordered_map ReversedInferShardableAxes( return value2shardable_axes; } -common::TopoWalker GetOpsTopoWalker(const std::vector& ops) { - using Cache = std::unordered_set; - auto ops_set = std::make_shared(ops.begin(), ops.end()); +common::TopoWalker GetOpsTopoWalker(const std::unordered_set& ops) { + const auto* ops_set = &ops; const auto VisitUpStreamInOps = [ops_set](const pir::Operation* op, const OpVisitor& DoEach) { VisitInputOp(op, [&](const auto* input){ if (ops_set->count(input) == 0) return; @@ -258,21 +257,26 @@ common::TopoWalker GetOpsTopoWalker(const std::vector GetStarts( - const common::TopoWalker& topo_walker, - const std::vector& ops) { - const auto IsStart = [&](const pir::Operation* op) { - size_t num_prevs = 0; - topo_walker.VisitPrevNodes(op, [&](const auto*){ ++num_prevs; }); - return num_prevs == 0; +std::list GetSinks( + const std::unordered_set& ops) { + const auto IsSink = [&](const pir::Operation* op) { + for (int i = 0; i < op->num_results(); ++i) { + pir::Value output = op->result(i); + for (auto consumer_it = output.use_begin(); consumer_it != output.use_end(); ++consumer_it) { + const auto* consumer_op = consumer_it->owner(); + if (consumer_op->isa()) continue; + if (ops.count(consumer_op) > 0) return false; + } + } + return true; }; - std::list starts; + std::list sinks; for (const auto* op : ops) { - if (IsStart(op)) { - starts.push_back(op); + if (IsSink(op)) { + sinks.push_back(op); } } - return starts; + return sinks; } class StmtFusionHelper { @@ -617,17 +621,12 @@ class StmtFusionHelper { ShardableAxesSignature GetShardableAxesSignature(const std::vector& ops) const { std::unordered_set ops_set(ops.begin(), ops.end()); - auto reversed_walker = GetOpsTopoWalker(ops); const pir::Operation* sink = [&]{ - const auto& sinks = GetStarts(reversed_walker, ops); + const auto& sinks = GetSinks(ops_set); CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node."; return *sinks.begin(); }(); - const auto& value2shardable_axes = [&]{ - size_t rank = GetRank(sink->result(0)); - const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank); - return ReversedInferShardableAxes(reversed_walker, sink, init_sa); - }(); + const auto& value2shardable_axes = InferShardableAxesFromSink(sink, ops_set); const auto& IsInputOpOperand = [&](const auto* op, int input_idx) { const auto& defining_op = op->operand_source(input_idx).defining_op(); return IsInThisFusionOp(defining_op) && ops_set.count(defining_op) == 0; @@ -678,10 +677,20 @@ GroupPattern GenerateGroupPatternFromFusionOp(const cinn::dialect::FusionOp& fus return FuseToGroupPattern(fusion_op); } -std::unordered_map InferShardableAxes(const std::vector& ops) { +std::unordered_map InferShardableAxesFromSink( + const pir::Operation* sink, + const std::unordered_set& ops) { + auto reversed_walker = GetOpsTopoWalker(ops); + CHECK_GT(ops.count(sink), 0); + size_t rank = GetRank(sink->result(0)); + const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank); + return ReversedInferShardableAxes(reversed_walker, sink, init_sa); +} + +std::unordered_map InferShardableAxes(const std::unordered_set& ops) { auto reversed_walker = GetOpsTopoWalker(ops); const pir::Operation* sink = [&]{ - const auto& sinks = GetStarts(reversed_walker, ops); + const auto& sinks = GetSinks(ops); CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node."; return *sinks.begin(); }(); diff --git a/paddle/cinn/frontend/group_pattern_util.h b/paddle/cinn/frontend/group_pattern_util.h index da46b2be050af..2b5f96b9c653f 100644 --- a/paddle/cinn/frontend/group_pattern_util.h +++ b/paddle/cinn/frontend/group_pattern_util.h @@ -6,6 +6,10 @@ namespace cinn::frontend { GroupPattern GenerateGroupPatternFromFusionOp(const cinn::dialect::FusionOp&); -std::unordered_map InferShardableAxes(const std::vector& ops); +std::unordered_map InferShardableAxes(const std::unordered_set& ops); + +std::unordered_map InferShardableAxesFromSink( + const pir::Operation* sink, + const std::unordered_set& ops); } \ No newline at end of file