diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc
index ac2d213b77868..ba146aa0dbd07 100644
--- a/paddle/cinn/frontend/group_pattern_util.cc
+++ b/paddle/cinn/frontend/group_pattern_util.cc
@@ -157,6 +157,124 @@ std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
   };
 }
 
+size_t GetRank(pir::Value value) {
+  return value.type().dyn_cast<pir::DenseTensorType>().dims().size();
+}
+
+ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(const pir::Operation* op) {
+  CHECK(!op->isa<cinn::dialect::ReshapeOp>()) << "reshape not supported. TODO(wuzhanfei).";
+  const size_t rank = [&]{
+    std::optional<size_t> rank;
+    for (int i = 0; i < op->num_operands(); ++i) {
+      if (rank.has_value()) {
+        CHECK_EQ(rank.value(), GetRank(op->operand_source(i)));
+      } else {
+        rank = GetRank(op->operand_source(i));
+      }
+    }
+    CHECK_EQ(op->num_results(), 1);
+    if (rank.has_value()) {
+      CHECK_EQ(rank.value(), GetRank(op->result(0)));
+    } else {
+      rank = GetRank(op->result(0));
+    }
+    CHECK(rank.has_value());
+    return rank.value();
+  }();
+  const ShardableAxes output_shardable_axes = ShardableAxesUtil::GetFullyShardableAxes(rank);
+  std::unordered_map<OpAndOperandIndex, ShardableAxes> input_shardable_axes;
+  for (int i = 0; i < op->num_operands(); ++i) {
+    input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes;
+  }
+  return ShardableAxesSignature{
+    .output_shardable_axes=output_shardable_axes,
+    .input_shardable_axes=input_shardable_axes,
+  };
+}
+
+ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(const pir::Operation* op) {
+  LOG(FATAL) << "TODO(wuzhanfei).";
+}
+
+ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) {
+  const hlir::framework::OpPatternKind kind = GetOpPatternKind(op);
+  if (kind == hlir::framework::kElementWise) {
+    return MakeShardableAxesSignature4ElementWiseOp(op);
+  } else if (kind == hlir::framework::kBroadcast) {
+    return MakeShardableAxesSignature4BroadcastOp(op);
+  } else {
+    LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->name();
op_name:" << op->name(); + } + LOG(FATAL) << "Dead code"; +} + +std::unordered_map ReversedInferShardableAxes( + common::TopoWalker& reversed_walker, + const pir::Operation* sink, + const ShardableAxes& init_sa) { + std::unordered_map value2shardable_axes{ + {sink->result(0), init_sa} + }; + const auto& UpdateValue2ShardableAxes = [&](pir::Value value, const ShardableAxes& sa) { + auto iter = value2shardable_axes.find(value); + if (iter != value2shardable_axes.end()) { + iter->second = ShardableAxesUtil::GetCommonShardableAxes(iter->second, sa); + } else { + iter->second = sa; + } + }; + reversed_walker(sink, [&](const auto* op){ + auto shardable_axes_sig = MakeShardableAxesSignature4Op(op); + const auto& old2new = ShardableAxesUtil::GetOldName2NewName(shardable_axes_sig.output_shardable_axes, + value2shardable_axes.at(op->result(0))); + for (auto& pair : shardable_axes_sig.input_shardable_axes) { + const auto& [my_op, input_idx] = pair.first; + CHECK_EQ(my_op, op); + auto* input_shardable_axes = &pair.second; + ShardableAxesUtil::UpdateShardableAxes(old2new, input_shardable_axes); + pir::Value input_value = op->operand_source(input_idx); + UpdateValue2ShardableAxes(input_value, *input_shardable_axes); + } + }); + return value2shardable_axes; +} + +common::TopoWalker GetOpsTopoWalker(const std::vector& ops) { + using Cache = std::unordered_set; + auto ops_set = std::make_shared(ops.begin(), ops.end()); + const auto VisitUpStreamInOps = [ops_set](const pir::Operation* op, const OpVisitor& DoEach) { + VisitInputOp(op, [&](const auto* input){ + if (ops_set->count(input) == 0) return; + DoEach(input); + }); + }; + const auto VisitDownStreamInOps = [ops_set](const pir::Operation* op, const OpVisitor& DoEach) { + VisitOutputOp(op, [&](const auto* output){ + if (ops_set->count(output) == 0) return; + DoEach(output); + }); + }; + common::TopoWalker reversed_walker(VisitDownStreamInOps, VisitUpStreamInOps); + return reversed_walker; +} + +std::list GetStarts( + const common::TopoWalker& topo_walker, + const std::vector& ops) { + const auto IsStart = [&](const pir::Operation* op) { + size_t num_prevs = 0; + topo_walker.VisitPrevNodes(op, [&](const auto*){ ++num_prevs; }); + return num_prevs == 0; + }; + std::list starts; + for (const auto* op : ops) { + if (IsStart(op)) { + starts.push_back(op); + } + } + return starts; +} + class StmtFusionHelper { public: explicit StmtFusionHelper(const cinn::dialect::FusionOp& fusion_op) @@ -409,57 +527,6 @@ class StmtFusionHelper { return std::nullopt; } - size_t GetRank(pir::Value value) const { - return value.type().dyn_cast().dims().size(); - }; - - ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) const { - const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); - if (kind == hlir::framework::kElementWise) { - return MakeShardableAxesSignature4ElementWiseOp(op); - } else if (kind == hlir::framework::kBroadcast) { - return MakeShardableAxesSignature4BroadcastOp(op); - } else { - LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->name(); - } - LOG(FATAL) << "Dead code"; - } - - ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(const pir::Operation* op) const { - CHECK(!op->isa()) << "reshape not supported. 
TODO(wuzhanfei)."; - const size_t rank = [&]{ - std::optional rank; - for (int i = 0; i < op->num_operands(); ++i) { - if (rank.has_value()) { - CHECK_EQ(rank.value(), GetRank(op->operand_source(i))); - } else { - rank = GetRank(op->operand_source(i)); - } - } - CHECK_EQ(op->num_results(), 1); - if (rank.has_value()) { - CHECK_EQ(rank.value(), GetRank(op->result(0))); - } else { - rank = GetRank(op->result(0)); - } - CHECK(rank.has_value()); - return rank.value(); - }(); - const ShardableAxes output_shardable_axes = ShardableAxesUtil::GetFullyShardableAxes(rank); - std::unordered_map input_shardable_axes; - for (int i = 0; i < op->num_operands(); ++i) { - input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes; - } - return ShardableAxesSignature{ - .output_shardable_axes=output_shardable_axes, - .input_shardable_axes=input_shardable_axes, - }; - } - - ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(const pir::Operation* op) const { - LOG(FATAL) << "TODO(wuzhanfei)."; - } - struct StmtIterPair { std::list::iterator upstream_iter; std::list::iterator downstream_iter; @@ -550,36 +617,13 @@ class StmtFusionHelper { ShardableAxesSignature GetShardableAxesSignature(const std::vector& ops) const { std::unordered_set ops_set(ops.begin(), ops.end()); - const auto VisitUpStreamInOps = [&](const pir::Operation* op, const OpVisitor& DoEach) { - VisitInputOp(op, [&](const auto* input){ - if (ops_set.count(input) == 0) return; - DoEach(input); - }); - }; - const auto VisitDownStreamInOps = [&](const pir::Operation* op, const OpVisitor& DoEach) { - VisitOutputOp(op, [&](const auto* output){ - if (ops_set.count(output) == 0) return; - DoEach(output); - }); - }; - const auto IsSinkOp = [&](const pir::Operation* op) { - size_t num_donwstreams = 0; - VisitDownStreamInOps(op, [&](const auto*){ ++num_donwstreams; }); - return num_donwstreams == 0; - }; + auto reversed_walker = GetOpsTopoWalker(ops); const pir::Operation* sink = [&]{ - std::optional sink; - for (const auto* op : ops) { - if (IsSinkOp(op)) { - CHECK(!sink.has_value()) << "only one sink node."; - } - sink = op; - } - CHECK(sink.has_value()); - return sink.value(); + const auto& sinks = GetStarts(reversed_walker, ops); + CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node."; + return *sinks.begin(); }(); const auto& value2shardable_axes = [&]{ - common::TopoWalker reversed_walker(VisitDownStreamInOps, VisitUpStreamInOps); size_t rank = GetRank(sink->result(0)); const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank); return ReversedInferShardableAxes(reversed_walker, sink, init_sa); @@ -611,37 +655,6 @@ class StmtFusionHelper { return shardable_axes_sig; } - std::unordered_map ReversedInferShardableAxes( - common::TopoWalker& reversed_walker, - const pir::Operation* sink, - const ShardableAxes& init_sa) const { - std::unordered_map value2shardable_axes{ - {sink->result(0), init_sa} - }; - const auto& UpdateValue2ShardableAxes = [&](pir::Value value, const ShardableAxes& sa) { - auto iter = value2shardable_axes.find(value); - if (iter != value2shardable_axes.end()) { - iter->second = ShardableAxesUtil::GetCommonShardableAxes(iter->second, sa); - } else { - iter->second = sa; - } - }; - reversed_walker(sink, [&](const auto* op){ - auto shardable_axes_sig = MakeShardableAxesSignature4Op(op); - const auto& old2new = ShardableAxesUtil::GetOldName2NewName(shardable_axes_sig.output_shardable_axes, - value2shardable_axes.at(op->result(0))); - for (auto& pair : shardable_axes_sig.input_shardable_axes) { 
-        const auto& [my_op, input_idx] = pair.first;
-        CHECK_EQ(my_op, op);
-        auto* input_shardable_axes = &pair.second;
-        ShardableAxesUtil::UpdateShardableAxes(old2new, input_shardable_axes);
-        pir::Value input_value = op->operand_source(input_idx);
-        UpdateValue2ShardableAxes(input_value, *input_shardable_axes);
-      }
-    });
-    return value2shardable_axes;
-  }
-
  private:
   cinn::dialect::FusionOp fusion_op_;
   std::function<bool(const pir::Operation*)> IsInThisFusionOp;
@@ -665,4 +678,19 @@ GroupPattern GenerateGroupPatternFromFusionOp(const cinn::dialect::FusionOp& fus
   return FuseToGroupPattern(fusion_op);
 }
 
+std::unordered_map<pir::Value, ShardableAxes> InferShardableAxes(const std::vector<const pir::Operation*>& ops) {
+  auto reversed_walker = GetOpsTopoWalker(ops);
+  const pir::Operation* sink = [&]{
+    const auto& sinks = GetStarts(reversed_walker, ops);
+    CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node.";
+    return *sinks.begin();
+  }();
+  const auto& value2shardable_axes = [&]{
+    size_t rank = GetRank(sink->result(0));
+    const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
+    return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
+  }();
+  return value2shardable_axes;
+}
+
 }
\ No newline at end of file