diff --git a/paddle/cinn/api/op_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h index 9b805cb891a569..b9582a9e6098b8 100644 --- a/paddle/cinn/api/op_topo_pattern.h +++ b/paddle/cinn/api/op_topo_pattern.h @@ -23,35 +23,36 @@ struct PartialShardablePattern {}; // Reduce base pattern template struct ReductionPattern { - explicit ReductionPattern(const ReductionPattern& other) = default; + using Nothing = std::monostate; - std::variant, PartialShardablePattern> opt_inputs; + std::variant, PartialShardablePattern> input; SingleReductionOpPattern reduction_op_pattern; + + bool HasFusedInput() const { + return !std::holds_alternative(this->input); + } }; -// // Stmt := IS | R | PS -// // ops in StmtPattern will be lowered into a inlined cuda code. -// template -// using StmtPattern = std::variant, ReductionPattern, PartialShardablePattern>; - -// // Stmts := [Stmt] -// template -// using StmtsPattern = std::list>; - -// // fuse rules: -// // 1. IS * IS -> IS -// // 2. PS * PS -> PS -// // 3. IS * PS -> PS -// // 4. IS * R -> R -// // 5. PS * R -> R - -// // lifting rules: -// // 1. R -> Stmts -// // 2. PS -> Stmts -// // 3. Stmts * Stmts -> Stmts - -// // OpTopoPattern := Error | Stmts -// template -// using OpTopoPattern = std::variant, StmtsPattern>; +// Stmt := IS | R | PS +// ops in StmtPattern will be lowered into a inlined cuda code. +template +using StmtPattern = std::variant, ReductionPattern, PartialShardablePattern>; + +// Stmts := [Stmt] +template +using StmtsPattern = std::vector>; +// fuse rules: +// 1. IS * IS -> IS +// 2. PS * PS -> PS +// 3. IS * PS -> PS +// 4. IS * R -> R +// 5. PS * R -> R +// lifting rules: +// 1. R -> Stmts +// 2. PS -> Stmts +// 3. Stmts * Stmts -> Stmts +// OpTopoPattern := Error | Stmts +template +using OpTopoPattern = std::variant, StmtsPattern>; } diff --git a/paddle/cinn/frontend/group_pattern.h b/paddle/cinn/frontend/group_pattern.h index 5fcfebc3df68cd..ea69cc1db06ca0 100644 --- a/paddle/cinn/frontend/group_pattern.h +++ b/paddle/cinn/frontend/group_pattern.h @@ -8,30 +8,37 @@ #include "paddle/cinn/api/op_topo_pattern.h" #include "paddle/pir/include/core/operation.h" #include "glog/logging.h" +#include "paddle/cinn/adt/adt.h" -namespace cinn::api { - -struct FrontendPattern {}; +namespace cinn::frontend { -template<> -struct ErrorPattern { - explicit ErrorPattern(const ErrorPattern& other) = default; +struct OpAndOperandIndex { + const pir::Operation* op; + const int operand_index; - std::vector ops; - std::string error_string; + bool operator==(const OpAndOperandIndex& other) const { + return this->op == other.op && this->operand_index == other.operand_index; + } }; -template<> -struct InjectiveSourcePattern { - explicit InjectiveSourcePattern(const InjectiveSourcePattern& other) = default; - std::vector ops; -}; +} + +namespace std { template<> -struct SingleReductionOpPattern { - explicit SingleReductionOpPattern(const SingleReductionOpPattern& other) = default; - const pir::Operation* reduce_op; +struct hash { + + size_t operator()(const cinn::frontend::OpAndOperandIndex& op_operand) const { + return cinn::adt::hash_combine(std::hash()(op_operand.op), op_operand.operand_index); + } }; + +} + +namespace cinn::frontend { + +struct FrontendPattern {}; + struct ShardableAxis { int axis; std::string axis_name; @@ -100,29 +107,40 @@ struct ShardableAxesUtil { }; struct ShardableAxesSignature { - using OpOperand = std::pair; - ShardableAxes output_shardable_axes; - std::unordered_map input_shardable_axes; + std::unordered_map input_shardable_axes; }; +} + +namespace cinn::api { + template<> -struct PartialShardablePattern { - explicit PartialShardablePattern(const PartialShardablePattern& other) = default; +struct ErrorPattern { + std::vector ops; + std::string error_string; +}; + +template<> +struct InjectiveSourcePattern { + std::vector ops; +}; +template<> +struct SingleReductionOpPattern { + const pir::Operation* reduce_op; +}; +template<> +struct PartialShardablePattern { std::vector ops; - ShardableAxesSignature shardable_axes_signature; + frontend::ShardableAxesSignature shardable_axes_signature; }; } namespace cinn::frontend { -using IS = api::InjectiveSourcePattern; -using R = api::ReductionPattern; -using PS = api::PartialShardablePattern; -using StmtPattern = std::variant; -using ErrorGroupPattern = api::ErrorPattern; -using GroupPattern = std::variant; +using ErrorGroupPattern = api::ErrorPattern; +using GroupPattern = api::OpTopoPattern; } \ No newline at end of file diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc index 8f560c3342e48a..6a61ee71ea33c0 100644 --- a/paddle/cinn/frontend/group_pattern_util.cc +++ b/paddle/cinn/frontend/group_pattern_util.cc @@ -2,6 +2,10 @@ #include "paddle/cinn/common/topo_walker.h" #include "paddle/cinn/common/bfs_walker.h" #include "paddle/cinn/hlir/framework/op.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" + #include #include #include @@ -12,7 +16,13 @@ namespace cinn::frontend { namespace { using OpPatternKind = cinn::hlir::framework::OpPatternKind; -using StmtIter = std::list::iterator; +using IS = api::InjectiveSourcePattern; +using R = api::ReductionPattern; +using PS = api::PartialShardablePattern; +using StmtPattern = api::StmtPattern; +using StmtsPattern = api::StmtsPattern; + +using StmtIter = StmtPattern*; using OpVisitor = std::function; using NodeVisitor = std::function; @@ -28,7 +38,7 @@ bool IsGeneralInjective(const pir::Operation* op) { || op_pattern_kind == hlir::framework::kInjective; } -bool IsISPattern(StmtPattern& pattern){ +bool IsISPattern(const StmtPattern& pattern){ return std::holds_alternative(pattern); } @@ -52,6 +62,7 @@ void VisitOutputOp(const pir::Operation* op, const OpVisitor& DoEach) { pir::Value output = op->result(i); for (auto consumer_it = output.use_begin(); consumer_it != output.use_end(); ++consumer_it) { const auto* consumer_op = consumer_it->owner(); + if (consumer_op->isa()) continue; DoEach(consumer_op); } } @@ -66,7 +77,7 @@ void VisitStmtOpImpl(const IS& injective_source, const DoEachT& DoEach) { template void VisitStmtOpImpl(const R& reduce, const DoEachT& DoEach) { - DoEach(reduce.reduce_op); + DoEach(reduce.reduction_op_pattern.reduce_op); } template @@ -82,9 +93,9 @@ void VisitStmtOp(const StmtPattern& stmt, const DoEachT& DoEach) { } std::function MakePredicatorIsInThisFusionOp(const cinn::dialect::FusionOp& fusion_op) { - std::set set; - for (const pir::Operation* op : fusion_op.block()->ops()) { - if (!op->isa()) { + std::set set; + for (const pir::Operation* op : fusion_op.GetOperators()) { + if (!op->isa<::pir::YieldOp>()) { set.insert(op); } } @@ -121,7 +132,7 @@ std::function MakePredicatorIsInjectiveSource( return starts; }(); - std::unordered_map op_2_is_injective_source; + std::unordered_map op_2_is_injective_source; auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) { bool is_inputs_all_injective_source = true; @@ -135,8 +146,8 @@ std::function MakePredicatorIsInjectiveSource( return is_inputs_all_injective_source; }; - common::TopoWalker walker{VisitEachInput, VisitEachOutput}; - walker(starts, [&](const pir::Operation* op){ + common::TopoWalker walker{VisitInputOp, VisitOutputOp}; + walker(starts.begin(), starts.end(), [&](const pir::Operation* op){ op_2_is_injective_source[op] = (IsGeneralInjective(op) && IsInputsAllInjectiveSource(op)); }); return [map = std::move(op_2_is_injective_source)](const pir::Operation* op) { @@ -154,8 +165,8 @@ class StmtFusionHelper { this->IsInjectiveSource = MakePredicatorIsInjectiveSource(fusion_op_, this->IsInThisFusionOp); } - std::list ConvertToStmtsPattern() const { - std::list ret; + std::vector ConvertToStmtsPattern() const { + std::vector ret; for (const auto* op : fusion_op_.GetOperators()) { if (!IsInThisFusionOp(op)) continue; ret.emplace_back(ConvertToStmtPattern(op)); @@ -163,12 +174,12 @@ class StmtFusionHelper { return ret; } - std::optional Fuse_IS_x_IS_2_IS(std::list* stmts) const { + std::optional Fuse_IS_x_IS_2_IS(std::vector* stmt_patterns) const { const auto ConstructISPattern = [&](const auto& ops) { return IS{ops}; }; - return MultiFuse(IsISPattern, ConstructISPattern, stmts); + return MultiFuse(IsISPattern, ConstructISPattern, stmt_patterns); } - std::optional Fuse_PS_x_PS_2_PS(std::list* stmt_patterns) const { + std::optional Fuse_PS_x_PS_2_PS(std::vector* stmt_patterns) const { const auto ConstructPSPattern = [&](const auto& ops) { const auto shardable_axes_signature = GetShardableAxesSignature(ops); return PS{ @@ -176,7 +187,7 @@ class StmtFusionHelper { .shardable_axes_signature=shardable_axes_signature, }; }; - return MultiFuse(IsPSPattern, ConstructISPattern, stmts); + return MultiFuse(IsPSPattern, ConstructPSPattern, stmt_patterns); } struct FusePolicy_IS_x_PS_2_PS { @@ -198,14 +209,20 @@ class StmtFusionHelper { return ops; }(); const auto& shardable_axes_signature = MergeShardableAxesSignature(upstream, downstream); - return PS{ + return StmtPattern(PS{ .ops=ops, .shardable_axes_signature=shardable_axes_signature, - }; + }); + } + + static ShardableAxesSignature MergeShardableAxesSignature( + const IS& upstream, + const PS& downstream) { + LOG(FATAL) << "TODO(tianchao)"; } }; - std::optional Fuse_IS_x_PS_2_PS(std::list* stmt_patterns) const { + std::optional Fuse_IS_x_PS_2_PS(std::vector* stmt_patterns) const { return FuseFilteredStmtPatterns(stmt_patterns); } struct FusePolicy_IS_x_R_2_R { @@ -219,19 +236,19 @@ class StmtFusionHelper { static std::variant MergePatternImpl( const IS& upstream, const R& downstream) { - if (downstream.opt_inputs.has_value()) { + if (downstream.HasFusedInput()) { return ErrorGroupPattern{ .ops={downstream.reduction_op_pattern.reduce_op}, .error_string="The input of reduce has been fused.", }; } R new_pattern = R(downstream); - new_pattern.opt_inputs = upstream; - return new_pattern; + new_pattern.input = upstream; + return StmtPattern(std::move(new_pattern)); } }; - std::optional Fuse_IS_x_R_2_R(std::list* stmt_patterns) const { + std::optional Fuse_IS_x_R_2_R(std::vector* stmt_patterns) const { return FuseFilteredStmtPatterns(stmt_patterns); } @@ -246,19 +263,19 @@ class StmtFusionHelper { static std::variant MergePatternImpl( const PS& upstream, const R& downstream) { - if (downstream.opt_inputs.has_value()) { + if (downstream.HasFusedInput()) { return ErrorGroupPattern{ .ops={downstream.reduction_op_pattern.reduce_op}, .error_string="The input of reduce has been fused.", }; } R new_pattern = R(downstream); - new_pattern.opt_inputs = upstream; - return new_pattern; + new_pattern.input = upstream; + return StmtPattern(new_pattern); } }; - std::optional Fuse_PS_x_R_2_R(std::list* stmt_patterns) const { + std::optional Fuse_PS_x_R_2_R(std::vector* stmt_patterns) const { return FuseFilteredStmtPatterns(stmt_patterns); } @@ -275,7 +292,7 @@ class StmtFusionHelper { } else if (kind == hlir::framework::kBroadcast) { return ConvertOpToPS(op); } else { - LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->op_name(); + LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->name(); } LOG(FATAL) << "Dead code"; } @@ -296,11 +313,11 @@ class StmtFusionHelper { }; } - static std::function(const pir::Operation*)> - MakeStmtFinderFromOp(std::list* stmts) { + using StmtIter4OpT = std::function(const pir::Operation*)>; + static StmtIter4OpT MakeStmtFinderFromOp(std::vector* stmts) { std::unordered_map op2stmt_iter; - for (auto iter = stmts->begin(); iter != stmts->end(); ++iter) { - VisitStmtOp(*iter, [&](const auto* op) { op2stmt_iter[op] = iter; }); + for (auto& stmt : *stmts) { + VisitStmtOp(stmt, [&](const auto* op) { op2stmt_iter[op] = &stmt; }); } return [map=std::move(op2stmt_iter)](const pir::Operation* op) -> std::optional { const auto iter = map.find(op); @@ -309,8 +326,8 @@ class StmtFusionHelper { }; } - std::function MakeTopoOrderFinderOfOp(cinn::dialect::FusionOp& fusion_op) const { - std::unordered_map op2order_in_block; + std::function MakeTopoOrderFinderOfOp(const cinn::dialect::FusionOp& fusion_op) const { + std::unordered_map op2order_in_block; size_t order = 0; for (const pir::Operation* op : fusion_op.GetOperators()) { op2order_in_block[op] = ++order; @@ -322,18 +339,17 @@ class StmtFusionHelper { }; } - template + template std::optional MultiFuse( - const IsDetailPatternT& IsDetailPattern, + const IsChozenPatternT& IsChozenPattern, const ConstructPatternT& ConstructPattern, - std::list* stmts) const { + std::vector* stmts) const { const auto StmtFinder = MakeStmtFinderFromOp(stmts); - const auto VisitInputStmt = [&](StmtIter stmt, const NodeVisitor& DoEach) { VisitStmtOp(*stmt, [&](const auto* op){ VisitInputOp(op, [&](const pir::Operation* input) { if (const auto& input_stmt = StmtFinder(input)) { - if (IsDetailPattern(input_stmt->value())) { + if (IsChozenPattern(*input_stmt.value())) { DoEach(input_stmt.value()); } } @@ -344,7 +360,7 @@ class StmtFusionHelper { VisitStmtOp(*stmt, [&](const auto* op){ VisitOutputOp(op, [&](const pir::Operation* output) { if (const auto& output_stmt = StmtFinder(output)) { - if (IsDetailPattern(*output_stmt.value())) { + if (IsChozenPattern(*output_stmt.value())) { DoEach(output_stmt.value()); } } @@ -352,10 +368,10 @@ class StmtFusionHelper { }); }; const auto IsSinkPattern = [&](StmtIter stmt) { - if (!IsDetailPattern(*stmt)) return false; + if (!IsChozenPattern(*stmt)) return false; std::size_t num_injective_src_outputs = 0; - VisitOutputStmt(node, [&](const auto& consumer) { - num_injective_src_outputs += IsDetailPattern(*consumer); + VisitOutputStmt(stmt, [&](const auto& consumer) { + num_injective_src_outputs += IsChozenPattern(*consumer); }); return num_injective_src_outputs == 0; }; @@ -366,25 +382,30 @@ class StmtFusionHelper { common::BfsWalker reverse_walker(VisitInputStmt); const auto& GetUpstreamOps = [&](const auto stmt_iter) { std::vector visited_ops; - reverse_walker(start, [&](const auto node){ - VisitStmtOp(node, [&](const auto* op) { visited_ops.push_back(op); }); + reverse_walker(stmt_iter, [&](const auto node){ + VisitStmtOp(*node, [&](const auto* op) { visited_ops.push_back(op); }); }); std::sort(visited_ops.begin(), visited_ops.end(), Cmp); return visited_ops; }; - std::list fused_stmts; - for (auto stmt_iter = stmts->begin(); stmt_iter != stmts->end(); ++stmt_iter) { - if (!IsSinkPattern(stmt_iter)) continue; - fused_stmts.emplace_back(ConstructPattern(GetUpstreamOps(stmt_iter))); - } - for (auto stmt_iter = stmts->begin(); stmt_iter != start->end();) { - if (IsDetailPattern(*stmt_iter)) { - stmt_iter = stmts->erase(stmt_iter); - } else { - ++stmt_iter; + + std::vector ret_stmts = [&]{ + std::vector ret_stmts; + ret_stmts.reserve(stmts->size()); + for (const auto& stmt : *stmts) { + if (!IsChozenPattern(stmt)) { + ret_stmts.push_back(stmt); + } else { + // do nothing. + } } + return ret_stmts; + }(); + for (auto& stmt : *stmts) { + if (!IsSinkPattern(&stmt)) continue; + ret_stmts.emplace_back(ConstructPattern(GetUpstreamOps(&stmt))); } - stmts->splice(stmts->begin(), std::move(fused_stmts)); + *stmts = ret_stmts; return std::nullopt; } @@ -399,7 +420,7 @@ class StmtFusionHelper { } else if (kind == hlir::framework::kBroadcast) { return MakeShardableAxesSignature4BroadcastOp(op); } else { - LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->op_name(); + LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->name(); } LOG(FATAL) << "Dead code"; } @@ -424,13 +445,13 @@ class StmtFusionHelper { CHECK(rank.has_value()); return rank.value(); }(); - const ShardableAxes shardable_axes = ShardableAxesUtil::GetFullyShardableAxes(rank); - std::unordered_map input_shardable_axes; + const ShardableAxes output_shardable_axes = ShardableAxesUtil::GetFullyShardableAxes(rank); + std::unordered_map input_shardable_axes; for (int i = 0; i < op->num_operands(); ++i) { - input_shardable_axes[std::pair(op, i)] = shardable_axes; + input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes; } return ShardableAxesSignature{ - .output_shardable_axes, + .output_shardable_axes=output_shardable_axes, .input_shardable_axes=input_shardable_axes, }; } @@ -440,45 +461,44 @@ class StmtFusionHelper { } struct StmtIterPair { - StmtIter upstream_iter; - StmtIter downstream_iter; + std::list::iterator upstream_iter; + std::list::iterator downstream_iter; }; - bool IsConnected(const StmtIter& upstream, const StmtIter& downstream){ - const auto StmtFinder = MakeStmtFinderFromOp({*upstream, *downstream}); + bool IsConnected(const StmtIter4OpT& StmtFinder, const StmtIter& upstream, const StmtIter& downstream) const { const auto VisitInputStmt = [&](StmtIter stmt, const NodeVisitor& DoEach) { - VisitStmtOp(*stmt, [&](const auto* op)){ + VisitStmtOp(*stmt, [&](const auto* op){ VisitInputOp(op, [&](const pir::Operation* input) { if (const auto& input_stmt = StmtFinder(input)) { - if (IsDetailPattern(input_stmt->value())) { - DoEach(input_stmt.value()); - } + DoEach(input_stmt.value()); } }); - }; + }); }; - auto downstream_input_patterns = std::unordered_set(); - VisitInputStmt(*downstream, [&](const StmtIter& input_pattern){ - downstream_input_patterns.insert(input_pattern); - }) - - return downstream_input_patterns.count(upstream) > 0; + bool found = false; + VisitInputStmt(downstream, [&](const StmtIter& input_pattern){ + if (input_pattern == upstream) { + found = true; + } + }); + return found; } template std::optional FindConnetedPattenPairWithCondition( - std::list* stmt_patterns, + const StmtIter4OpT& StmtFinder, + std::list* stmt_iters, const FuseTargetConditionT& FuseTargetCondition) const { - for (auto dst_iter = stmt_patterns->begin(); dst_iter != stmt_patterns->end(); ++dst_iter) { - for (auto src_iter = stmt_patterns->begin(); src_iter != stmt_patterns->end(); ++src_iter) { + for (auto dst_iter = stmt_iters->begin(); dst_iter != stmt_iters->end(); ++dst_iter) { + for (auto src_iter = stmt_iters->begin(); src_iter != stmt_iters->end(); ++src_iter) { if (src_iter == dst_iter) continue; - if (!IsConnected(*src_iter, *dst_iter)) continue; - if (FuseTargetCondition(*src_iter, *dst_iter)) { - return StmtPattern{ + if (!IsConnected(StmtFinder, *src_iter, *dst_iter)) continue; + if (FuseTargetCondition(**src_iter, **dst_iter)) { + return StmtIterPair{ .upstream_iter=src_iter, .downstream_iter=dst_iter, - } + }; } } } @@ -487,21 +507,44 @@ class StmtFusionHelper { template std::optional FuseFilteredStmtPatterns( - std::list* stmt_patterns) const{ + std::vector* stmt_patterns) const{ + std::list stmts_iters = [&]{ + std::list stmts_iters; + for (auto& stmt : *stmt_patterns) { + stmts_iters.push_back(&stmt); + } + return stmts_iters; + }(); + const auto StmtFinder = MakeStmtFinderFromOp(stmt_patterns); + const auto EraseOld = [&](const StmtIterPair& pattern_pair) { + stmts_iters.erase(pattern_pair.upstream_iter); + stmts_iters.erase(pattern_pair.downstream_iter); + }; + const auto& InsertNew = [&](const StmtPattern& stmt_pattern) { + stmt_patterns->push_back(stmt_pattern); + stmts_iters.push_back(&stmt_patterns->back()); + }; while(true){ const auto& pattern_pair = FindConnetedPattenPairWithCondition( - stmt_patterns, &FusionPolicy::FuseCondition); - if (!pattern_pair.value()) break; + StmtFinder, &stmts_iters, &FusionPolicy::FuseCondition); + if (!pattern_pair.has_value()) break; const std::variant& new_pattern = - FusionPolicy::MergePattern(*pattern_pair.value().upstream_iter, *pattern_pair.value().downstream_iter); + FusionPolicy::MergePattern(**pattern_pair.value().upstream_iter, **pattern_pair.value().downstream_iter); - if (std::holds_alternative(new_pattern)){ + if (std::holds_alternative(new_pattern)) { return std::get(new_pattern); } - stmt_patterns->erase(pattern_pair.value().upstream_iter); - stmt_patterns->erase(pattern_pair.value().downstream_iter); - stmt_patterns->emplace_back(std::get(new_pattern)); + EraseOld(pattern_pair.value()); + InsertNew(std::get(new_pattern)); } + *stmt_patterns = [&]{ + std::vector ret_patterns; + ret_patterns.reserve(stmts_iters.size()); + for (const auto& stmt_iter : stmts_iters) { + ret_patterns.push_back(*stmt_iter); + } + return ret_patterns; + }(); return std::nullopt; } @@ -542,28 +585,28 @@ class StmtFusionHelper { return ReversedInferShardableAxes(reversed_walker, sink, init_sa); }(); const auto& IsInputOpOperand = [&](const auto* op, int input_idx) { - const auto& defining_op = op->operand_source(input_idx)->defining_op(); + const auto& defining_op = op->operand_source(input_idx).defining_op(); return IsInThisFusionOp(defining_op) && ops_set.count(defining_op) == 0; }; - using OpOperandT = std::pair; const auto& input_op_operands = [&]{ - std::vector op_operands; + std::vector op_operands; for (const auto* op : ops) { for (int i = 0; i < op->num_operands(); ++i) { if (!IsInputOpOperand(op, i)) continue; - op_operands.emplace_back({op, i}); + op_operands.emplace_back(OpAndOperandIndex{op, i}); } } return op_operands; }(); const auto& shardable_axes_sig = [&]{ ShardableAxesSignature signature; - ShardableAxesSignature.output_shardable_axes = value2shardable_axes.at(sink->result(0)); + signature.output_shardable_axes = value2shardable_axes.at(sink->result(0)); for (const auto& pair : input_op_operands) { const auto& [op, idx] = pair; pir::Value input = op->operand_source(idx); - ShardableAxesSignature.input_shardable_axes[pair] = value2shardable_axes.at(input); + signature.input_shardable_axes[pair] = value2shardable_axes.at(input); } + return signature; }(); return shardable_axes_sig; } @@ -607,7 +650,7 @@ class StmtFusionHelper { GroupPattern FuseToGroupPattern(const cinn::dialect::FusionOp& fusion_op) { StmtFusionHelper helper(fusion_op); - std::list stmt_patterns = helper.ConvertToStmtsPattern(); + std::vector stmt_patterns = helper.ConvertToStmtsPattern(); if (const auto& error = helper.Fuse_IS_x_IS_2_IS(&stmt_patterns)) return error.value(); if (const auto& error = helper.Fuse_PS_x_PS_2_PS(&stmt_patterns)) return error.value(); if (const auto& error = helper.Fuse_IS_x_PS_2_PS(&stmt_patterns)) return error.value();