From 08eb16d3211a4b0725ca0b633bd55ce5c77de672 Mon Sep 17 00:00:00 2001 From: feifei-111 <2364819892@qq.com> Date: Wed, 6 Mar 2024 13:53:52 +0000 Subject: [PATCH] update --- paddle/cinn/api/op_topo_pattern.h | 6 +-- paddle/cinn/frontend/group_pattern.h | 2 +- paddle/cinn/frontend/group_pattern_util.cc | 58 ++++++++++++++++++---- 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/paddle/cinn/api/op_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h index 1273b0b37280a..5d680bfd960f3 100644 --- a/paddle/cinn/api/op_topo_pattern.h +++ b/paddle/cinn/api/op_topo_pattern.h @@ -20,7 +20,7 @@ struct PartialShardablePattern {}; template struct ReductionPattern { using Nothing = std::monostate; - std::variant, PartialShardablePattern> opt_is_or_ps_input; + std::variant, PartialShardablePattern> opt_inputs; SingleReductionOpPattern reduction_op_pattern; }; @@ -30,8 +30,8 @@ template using ShardableReductionsPattern = std::vector, PartialShardablePattern>>; // fuse rules: -// 1. IS * PS -> PS -// 2. PS * PS -> PS +// 1. PS * PS -> PS +// 2. IS * PS -> PS // 3. IS * R -> R // 4. PS * R -> R diff --git a/paddle/cinn/frontend/group_pattern.h b/paddle/cinn/frontend/group_pattern.h index b45c05f79a706..75be679021ab5 100644 --- a/paddle/cinn/frontend/group_pattern.h +++ b/paddle/cinn/frontend/group_pattern.h @@ -30,7 +30,7 @@ struct ShardableAxes { struct ShardableAxesSignature { using OpOperand = std::pair; - ShardableAxes output_shardable_axes; + std::vector output_shardable_axes; std::unordered_map input_shardable_axes; }; diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc index e42b77dc2017a..87194b60760d2 100644 --- a/paddle/cinn/frontend/group_pattern_util.cc +++ b/paddle/cinn/frontend/group_pattern_util.cc @@ -113,29 +113,67 @@ struct InternalFusionHelper { TODO(); } + std::variant MergePattern( + const IS& upstream, + const PS& downstream){ + PS new_pattern = CopyPattern(downstream); + new_pattern.ops.insert(new_pattern.end(), upstream.begin(), upstream.end()); + return new_pattern; + } + + std::variant MergePattern( + const PS& upstream, + const PS& downstream){ + PS new_pattern = CopyPattern(downstream); + new_pattern.ops.insert(new_pattern.end(), upstream.begin(), upstream.end()); + new_pattern.shardable_axes_signature.output_shardable_axes.insert( + new_pattern.shardable_axes_signature.output_shardable_axes.end(), + upstream.shardable_axes_signature.output_shardable_axes.begin(), + upstream.shardable_axes_signature.output_shardable_axes.end() + ); + new_pattern.shardable_axes_signature.input_shardable_axes.insert( + upstream.shardable_axes_signature.input_shardable_axes.begin(), + upstream.shardable_axes_signature.input_shardable_axes.end() + ); + return new_pattern + } + + std::variant MergePattern( + const IS& upstream, + const R& downstream){ + R new_pattern = CopyPattern(downstream); + new_pattern.opt_inputs = CopyPattern(upstream); + return new_pattern; + } + + std::variant MergePattern( + const PS& upstream, + const R& downstream){ + R new_pattern = CopyPattern(downstream); + new_pattern.opt_inputs = CopyPattern(upstream); + return new_pattern; + } + std::optional> FindConnetedPattenPairWithCondition( std::vector* internal_patterns, - std::function& FuseTargetCondition /* first input is upstream, second is downstream */) const { + std::function& FuseTargetCondition) const { for (int i=0; i FuseIternalPattenPrototype( std::vector* internal_patterns, std::function& FuseTargetCondition) const{ @@ -147,7 +185,9 @@ struct InternalFusionHelper { if (!pattern_pair.value()){ break; } - const InternalPattern& new_pattern = MergePattern(pattern_pair.first, pattern_pair.second); + const std::variant& new_pattern = + MergePattern(pattern_pair.first, pattern_pair.second); + if (IsErrorGroupPattern(new_pattern)){ return new_pattern; } @@ -202,8 +242,8 @@ std::variant, ErrorGroupPattern> InternalFusion(con const auto& IsInjectiveSource = MakeGetterIsInjectiveSource(fusion_op, IsInThisFusionOp); InternalFusionHelper helper{IsInThisFusionOp, IsInjectiveSource}; std::vector internal_patterns = helper.FuseISAndConvertRemainder(fusion_op); - if (const auto& opt_error = helper.Fuse_IS_x_PS_2_PS(&internal_patterns)) return opt_error.value(); if (const auto& opt_error = helper.Fuse_PS_x_PS_2_PS(&internal_patterns)) return opt_error.value(); + if (const auto& opt_error = helper.Fuse_IS_x_PS_2_PS(&internal_patterns)) return opt_error.value(); if (const auto& opt_error = helper.Fuse_IS_x_R_2_R(&internal_patterns)) return opt_error.value(); if (const auto& opt_error = helper.Fuse_PS_x_R_2_R(&internal_patterns)) return opt_error.value(); return internal_patterns;