diff --git a/paddle/cinn/frontend/cluster_ops/fusion_helper.cc b/paddle/cinn/frontend/cluster_ops/fusion_helper.cc index 4cab75766cd55..42e6dff340e34 100644 --- a/paddle/cinn/frontend/cluster_ops/fusion_helper.cc +++ b/paddle/cinn/frontend/cluster_ops/fusion_helper.cc @@ -110,6 +110,83 @@ std::optional StmtFusionHelper::Fuse_PS_x_R_2_R( return FuseFilteredStmtPatterns(stmt_patterns); } +bool StmtFusionHelper::FusePolicy_IS_x_PS_2_PS::FuseCondition(const StmtPattern& upstream, + const StmtPattern& downstream) { + return IsISPattern(upstream) && IsPSPattern(downstream); +} +std::variant StmtFusionHelper::FusePolicy_IS_x_PS_2_PS::MergePattern( + const StmtPattern& upstream, const StmtPattern& downstream) { + return MergePatternImpl(std::get(upstream), std::get(downstream)); +} +std::variant StmtFusionHelper::FusePolicy_IS_x_PS_2_PS::MergePatternImpl( + const IS& upstream, const PS& downstream) { + const auto& ops = [&] { + std::vector ops(upstream.ops.begin(), + upstream.ops.end()); + for (const auto* downstream_op : downstream.ops) { + if (std::find(ops.begin(), ops.end(), downstream_op) == ops.end()) { + ops.push_back(downstream_op); + } + } + return ops; + }(); + const auto& shardable_axes_signature = + MergeShardableAxesSignature(upstream, downstream); + return StmtPattern(PS{ + .ops = ops, + .sole_sink = downstream.sole_sink, + .shardable_axes_signature = shardable_axes_signature, + }); +} + +ShardableAxesSignature StmtFusionHelper::FusePolicy_IS_x_PS_2_PS::MergeShardableAxesSignature( + const IS& upstream, const PS& downstream) { + LOG(FATAL) << "TODO(tianchao)"; +} + + +bool StmtFusionHelper::FusePolicy_IS_x_R_2_R::FuseCondition(const StmtPattern& upstream, + const StmtPattern& downstream) { + return IsISPattern(upstream) && IsRPattern(downstream); +} +std::variant StmtFusionHelper::FusePolicy_IS_x_R_2_R::MergePattern( + const StmtPattern& upstream, const StmtPattern& downstream) { + return MergePatternImpl(std::get(upstream), std::get(downstream)); +} +std::variant StmtFusionHelper::FusePolicy_IS_x_R_2_R::MergePatternImpl( + const IS& upstream, const R& downstream) { + if (downstream.HasFusedInput()) { + return ErrorGroupPattern{ + .ops = {downstream.reduce_op_pattern.reduce_op}, + .error_string = "The input of reduce has been fused.", + }; + } + R new_pattern = R(downstream); + new_pattern.input = upstream; + return StmtPattern(std::move(new_pattern)); +} + +bool StmtFusionHelper::FusePolicy_PS_x_R_2_R::FuseCondition(const StmtPattern& upstream, + const StmtPattern& downstream) { + return IsISPattern(upstream) && IsRPattern(downstream); +} +std::variant StmtFusionHelper::FusePolicy_PS_x_R_2_R::MergePattern( + const StmtPattern& upstream, const StmtPattern& downstream) { + return MergePatternImpl(std::get(upstream), std::get(downstream)); +} +std::variant StmtFusionHelper::FusePolicy_PS_x_R_2_R::MergePatternImpl( + const PS& upstream, const R& downstream) { + if (downstream.HasFusedInput()) { + return ErrorGroupPattern{ + .ops = {downstream.reduce_op_pattern.reduce_op}, + .error_string = "The input of reduce has been fused.", + }; + } + R new_pattern = R(downstream); + new_pattern.input = upstream; + return StmtPattern(new_pattern); +} + StmtPattern StmtFusionHelper::ConvertToStmtPattern(const pir::Operation* op) { const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); if (IsInjectiveSource(op)) {