From b248e382f8fecb361e22f7744d6a14960029d9e1 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 9 Jan 2024 17:16:17 +0000 Subject: [PATCH 1/4] refine fully_insert_broadcast_pass --- .../operator/transforms/CMakeLists.txt | 4 +- ...dcast_pass.cc => insert_broadcast_pass.cc} | 86 +++++++++++-------- ...oadcast_pass.h => insert_broadcast_pass.h} | 10 +-- .../interface/infer_symbolic_shape.cc | 6 +- .../pir/transforms/shape_optimization_pass.cc | 3 +- paddle/pir/dialect/shape/utils/shape_utils.h | 4 + 6 files changed, 62 insertions(+), 51 deletions(-) rename paddle/cinn/hlir/dialect/operator/transforms/{fully_insert_broadcast_pass.cc => insert_broadcast_pass.cc} (53%) rename paddle/cinn/hlir/dialect/operator/transforms/{fully_insert_broadcast_pass.h => insert_broadcast_pass.h} (71%) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index 9b57f4f14edea..bde1763097b5f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -30,9 +30,9 @@ if(NOT CINN_ONLY) op_dialect_vjp) cinn_cc_library( - fully_insert_broadcast_pass + insert_broadcast_pass SRCS - fully_insert_broadcast_pass.cc + insert_broadcast_pass.cc DEPS pir cinn_op_dialect diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc similarity index 53% rename from paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc rename to paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc index 04ba01b4cbea2..ec26b80598b74 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h" #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" #include "paddle/cinn/hlir/framework/pir/utils.h" @@ -44,6 +44,14 @@ pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter, bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { pir::Value x = op->operand_source(0); pir::Value y = op->operand_source(1); + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(module_op.program()); + const auto& x_shape = shape_analysis.GetShapeOrDataForValue(&x); + const auto& y_shape = shape_analysis.GetShapeOrDataForValue(&y); + if (x_shape.shape() == y_shape.shape() && x_shape.data() == y_shape.data()) { + return false; + } + pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y); { pir::Value broadcasted_x = @@ -59,7 +67,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { } template -class FullyInsertBroadcastPattern : public pir::OpRewritePattern { +class InsertBroadcastPattern : public pir::OpRewritePattern { public: using pir::OpRewritePattern::OpRewritePattern; @@ -69,42 +77,46 @@ class FullyInsertBroadcastPattern : public pir::OpRewritePattern { } }; -FullyInsertBroadcastPass::FullyInsertBroadcastPass() - : pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {} - -pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns( - pir::IrContext* context) { - pir::RewritePatternSet ps(context); - // elementwise ops - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>( - context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - - // compare ops - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - - // bitwise ops - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - - return ps; -} +class InsertBroadcastPass : public pir::PatternRewritePass { + public: + InsertBroadcastPass() : pir::PatternRewritePass("insert_broadcast_pass", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + // elementwise ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + // compare ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + // bitwise ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + return ps; + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->isa() && op->num_regions() > 0; + } +}; -bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const { - return op->isa() && op->num_regions() > 0; +std::unique_ptr CreateInsertBroadcastPass() { + return std::make_unique(); } } // namespace ir diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h similarity index 71% rename from paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h rename to paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h index ba17458399278..d3f5e489a682a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h @@ -15,20 +15,12 @@ #pragma once #include "paddle/pir/pass/pass.h" -#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" namespace cinn { namespace dialect { namespace ir { -class FullyInsertBroadcastPass : public pir::PatternRewritePass { - public: - FullyInsertBroadcastPass(); - - pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override; - - bool CanApplyOn(pir::Operation *op) const override; -}; +IR_API std::unique_ptr CreateInsertBroadcastPass(); } // namespace ir } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index cd324b5f05c69..34eb504d08446 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -68,7 +68,11 @@ bool InferSymbolicShapeElementWiseBinary( std::vector shapes; symbol::DimExprBuilder builder{nullptr}; for (size_t i = 0; i < shape_0.size(); i++) { - shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i])); + if (shape_0[i] == shape_1[i]) { + shapes.emplace_back(shape_0[i]); + } else { + shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i])); + } } // TODO(lanxianghit): fill data when the operation is on shape computation diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 2bf33603fa7be..325f940417fa2 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -380,9 +380,8 @@ void DebugPrintOpInfo( } void InferSymExprForAllValues(ModuleOp module_op) { - auto shape_analysis_mgr = ShapeAnalysisManager::Instance(); ShapeConstraintIRAnalysis& shape_analysis = - shape_analysis_mgr.Get(module_op.program()); + ShapeAnalysisManager::Instance().Get(module_op.program()); for (uint32_t i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { for (auto& op : block) { diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 09a2aba1d15f2..28d939ef7558f 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -126,6 +126,10 @@ class IR_API ShapeAnalysisManager { static ShapeAnalysisManager& Instance(); ShapeConstraintIRAnalysis& Get(pir::Program* program); + ShapeAnalysisManager(const ShapeAnalysisManager&) = delete; + ShapeAnalysisManager(ShapeAnalysisManager&&) = delete; + ShapeAnalysisManager& operator=(const ShapeAnalysisManager&) = delete; + private: ShapeAnalysisManager() {} std::unordered_map tables_; From ae007b4d912d19d813abc62e0163a4f2fc33962a Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 10 Jan 2024 02:15:51 +0000 Subject: [PATCH 2/4] fix complie bug --- .../hlir/dialect/operator/transforms/insert_broadcast_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc index ec26b80598b74..65f4f8d314b7a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc @@ -45,7 +45,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { pir::Value x = op->operand_source(0); pir::Value y = op->operand_source(1); pir::ShapeConstraintIRAnalysis& shape_analysis = - pir::ShapeAnalysisManager::Instance().Get(module_op.program()); + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); const auto& x_shape = shape_analysis.GetShapeOrDataForValue(&x); const auto& y_shape = shape_analysis.GetShapeOrDataForValue(&y); if (x_shape.shape() == y_shape.shape() && x_shape.data() == y_shape.data()) { From 9736828b80b964726ae62c17680dcd140be11e9f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 10 Jan 2024 05:14:44 +0000 Subject: [PATCH 3/4] fix complie --- paddle/fluid/pybind/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index c12d278fcbed1..8a7d5420be39d 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -57,7 +57,7 @@ if(WITH_CINN) add_broadcast_to_elementwise_pass pd_to_cinn_pass sub_graph_checker - fully_insert_broadcast_pass + insert_broadcast_pass fuse_shape_ops_into_generate_shape_op_pass split_generate_shape_into_shape_ops_pass) endif() From 027bb8daeb2016570d07592259e21b277298997c Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 11 Jan 2024 13:20:30 +0000 Subject: [PATCH 4/4] fix conflict --- .../hlir/dialect/operator/transforms/insert_broadcast_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc index 7abb585c7789e..7819bc362f577 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc @@ -52,8 +52,8 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { pir::Value y = op->operand_source(1); pir::ShapeConstraintIRAnalysis& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); - const auto& x_shape = shape_analysis.GetShapeOrDataForValue(&x); - const auto& y_shape = shape_analysis.GetShapeOrDataForValue(&y); + const auto& x_shape = shape_analysis.GetShapeOrDataForValue(x); + const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y); if (x_shape.shape() == y_shape.shape() && x_shape.data() == y_shape.data()) { return false; }