diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index 8439775348a49d..6c5f09c3ebe3d3 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -11,9 +11,7 @@ if(NOT CINN_ONLY) cinn_runtime_dialect pir_compiler) - cc_library( - cinn_transforms - SRCS ${cinn_transforms_srcs} - DEPS ${cinn_transforms_deps}) + cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS + ${cinn_transforms_deps}) endif() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc similarity index 55% rename from paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc rename to paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc index e5347281e009a4..7819bc362f5774 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h" #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" #include "paddle/cinn/hlir/framework/pir/utils.h" @@ -50,6 +50,14 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { } pir::Value x = op->operand_source(0); pir::Value y = op->operand_source(1); + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + const auto& x_shape = shape_analysis.GetShapeOrDataForValue(x); + const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y); + if (x_shape.shape() == y_shape.shape() && x_shape.data() == y_shape.data()) { + return false; + } + pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y); { pir::Value broadcasted_x = @@ -67,7 +75,7 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { } // namespace template -class FullyInsertBroadcastPattern : public pir::OpRewritePattern { +class InsertBroadcastPattern : public pir::OpRewritePattern { public: using pir::OpRewritePattern::OpRewritePattern; @@ -77,42 +85,46 @@ class FullyInsertBroadcastPattern : public pir::OpRewritePattern { } }; -FullyInsertBroadcastPass::FullyInsertBroadcastPass() - : pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {} - -pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns( - pir::IrContext* context) { - pir::RewritePatternSet ps(context); - // elementwise ops - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>( - context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - - // compare ops - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - - // bitwise ops - ps.Add>(context); - ps.Add>(context); - ps.Add>(context); - - return ps; -} +class InsertBroadcastPass : public pir::PatternRewritePass { + public: + InsertBroadcastPass() : pir::PatternRewritePass("insert_broadcast_pass", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + // elementwise ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + // compare ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + // bitwise ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + return ps; + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->isa() && op->num_regions() > 0; + } +}; -bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const { - return op->isa() && op->num_regions() > 0; +std::unique_ptr CreateInsertBroadcastPass() { + return std::make_unique(); } } // namespace ir diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h similarity index 71% rename from paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h rename to paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h index ba174583992784..d3f5e489a682a5 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h @@ -15,20 +15,12 @@ #pragma once #include "paddle/pir/pass/pass.h" -#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" namespace cinn { namespace dialect { namespace ir { -class FullyInsertBroadcastPass : public pir::PatternRewritePass { - public: - FullyInsertBroadcastPass(); - - pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override; - - bool CanApplyOn(pir::Operation *op) const override; -}; +IR_API std::unique_ptr CreateInsertBroadcastPass(); } // namespace ir } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 5d4cc10b205ba4..1d57702349c903 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -67,7 +67,11 @@ bool InferSymbolicShapeElementWiseBinary( std::vector shapes; symbol::DimExprBuilder builder{nullptr}; for (size_t i = 0; i < shape_0.size(); i++) { - shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i])); + if (shape_0[i] == shape_1[i]) { + shapes.emplace_back(shape_0[i]); + } else { + shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i])); + } } // TODO(lanxianghit): fill data when the operation is on shape computation diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 485bb8c15f8ba8..80ad7f6248403d 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -83,9 +83,8 @@ void DebugPrintOpInfo( } void InferSymExprForAllValues(ModuleOp module_op) { - auto shape_analysis_mgr = ShapeAnalysisManager::Instance(); ShapeConstraintIRAnalysis& shape_analysis = - shape_analysis_mgr.Get(module_op.program()); + ShapeAnalysisManager::Instance().Get(module_op.program()); for (uint32_t i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { for (auto& op : block) { diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 47ec6d58637dd0..ca872d2baaec32 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -120,6 +120,10 @@ class IR_API ShapeAnalysisManager { static ShapeAnalysisManager& Instance(); ShapeConstraintIRAnalysis& Get(pir::Program* program); + ShapeAnalysisManager(const ShapeAnalysisManager&) = delete; + ShapeAnalysisManager(ShapeAnalysisManager&&) = delete; + ShapeAnalysisManager& operator=(const ShapeAnalysisManager&) = delete; + private: ShapeAnalysisManager() {} std::unordered_map tables_; diff --git a/paddle/pir/pass/utils.cc b/paddle/pir/pass/utils.cc index 91d5975a07b5dd..f56a41c472446a 100644 --- a/paddle/pir/pass/utils.cc +++ b/paddle/pir/pass/utils.cc @@ -18,10 +18,11 @@ namespace pir { namespace detail { void PrintHeader(const std::string &header, std::ostream &os) { - unsigned padding = (80 - header.size()) / 2; - os << "===" << std::string(73, '-') << "===\n"; + const size_t padding = 8; + size_t line_len = header.size() + ((padding - 3) * 2); + os << "===" << std::string(line_len, '-') << "===\n"; os << std::string(padding, ' ') << header << "\n"; - os << "===" << std::string(73, '-') << "===\n"; + os << "===" << std::string(line_len, '-') << "===\n"; } } // namespace detail