From 6c35404b060591150062e81f3560c2ca194f8713 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 8 Jan 2024 13:13:52 +0000 Subject: [PATCH 1/4] [Fix Bug] Fix Bugs of Two Pass --- .../transforms/fully_insert_broadcast_pass.cc | 8 ++++ ...e_shape_ops_into_generate_shape_op_pass.cc | 41 +++++++++++-------- ...se_shape_ops_into_generate_shape_op_pass.h | 8 +--- paddle/fluid/pybind/CMakeLists.txt | 10 ++++- 4 files changed, 41 insertions(+), 26 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc index 04ba01b4cbea2..e5347281e009a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc @@ -32,6 +32,8 @@ namespace cinn { namespace dialect { namespace ir { +namespace { + pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter, pir::Value x, pir::Value y) { @@ -42,6 +44,10 @@ pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter, } bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { + if (op->operand_source(0).defining_op()->isa() && + op->operand_source(1).defining_op()->isa()) { + return false; + } pir::Value x = op->operand_source(0); pir::Value y = op->operand_source(1); pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y); @@ -58,6 +64,8 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { return true; } +} // namespace + template class FullyInsertBroadcastPattern : public pir::OpRewritePattern { public: diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 0bff6c7daa886..9a9057e993be7 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -27,6 +27,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pattern_rewrite/pattern_applicator.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" @@ -38,6 +39,9 @@ namespace ir { namespace { +using ShapeOrDataDimExprs4ValueT = + std::function; + std::vector FindSourceDenseTensorOfDimTensor( pir::Value shape, const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { @@ -126,9 +130,19 @@ std::optional GetOutOfRewritedGenerateShapeOp( .out(); } -bool ProcessOp(paddle::dialect::ExpandOp op, - pir::PatternRewriter* rewriter, - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { +bool ProcessOp(paddle::dialect::ExpandOp op, pir::PatternRewriter* rewriter) { + if (op.shape().defining_op()->isa()) { + return false; + } + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value = + [&op](pir::Value value) -> symbol::ShapeOrDataDimExprs { + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get( + op.x().defining_op()->GetParentProgram()); + CHECK(shape_analysis.value_id_to_shapeordata_.find(GetValueId(&value)) != + shape_analysis.value_id_to_shapeordata_.end()); + return shape_analysis.value_id_to_shapeordata_.at(GetValueId(&value)); + }; std::optional opt_generated_shape = GetOutOfRewritedGenerateShapeOp( op.shape(), rewriter, ShapeOrDataDimExprs4Value); @@ -143,32 +157,25 @@ template class FuseShapeOpsIntoGenerateShapeOpPattern : public pir::OpRewritePattern { public: - FuseShapeOpsIntoGenerateShapeOpPattern( - pir::IrContext* context, - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) - : pir::OpRewritePattern(context), - ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {} + explicit FuseShapeOpsIntoGenerateShapeOpPattern(pir::IrContext* context) + : pir::OpRewritePattern(context) {} bool MatchAndRewrite(OPTYPE op, pir::PatternRewriter& rewriter) const override { - return ProcessOp(op, &rewriter, ShapeOrDataDimExprs4Value_); + return ProcessOp(op, &rewriter); } - - private: - ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_; }; -FuseShapeOpsIntoGenerateShapeOpPass::FuseShapeOpsIntoGenerateShapeOpPass( - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) - : pir::PatternRewritePass("fuse_shape_ops_into_generate_shape_op_pass", 1), - ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {} +FuseShapeOpsIntoGenerateShapeOpPass::FuseShapeOpsIntoGenerateShapeOpPass() + : pir::PatternRewritePass("fuse_shape_ops_into_generate_shape_op_pass", 1) { +} pir::RewritePatternSet FuseShapeOpsIntoGenerateShapeOpPass::InitializePatterns( pir::IrContext* context) { pir::RewritePatternSet ps(context); // elementwise ops ps.Add>( - context, ShapeOrDataDimExprs4Value_); + context); return ps; } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h index 393ae49825182..3f74db98aba44 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h @@ -24,17 +24,11 @@ namespace ir { class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass { public: - using ShapeOrDataDimExprs4ValueT = - std::function; - explicit FuseShapeOpsIntoGenerateShapeOpPass( - const ShapeOrDataDimExprs4ValueT &ShapeOrDataDimExprs4Value); + FuseShapeOpsIntoGenerateShapeOpPass(); pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override; bool CanApplyOn(pir::Operation *op) const override; - - private: - ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_; }; } // namespace ir diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 52eada6c5482f..e2b6b7009f3cf 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -51,8 +51,14 @@ set(PYBIND_DEPS if(WITH_CINN) set(PYBIND_DEPS - ${PYBIND_DEPS} pir_transforms op_with_group_merge_pass - add_broadcast_to_elementwise_pass pd_to_cinn_pass sub_graph_checker) + ${PYBIND_DEPS} + pir_transforms + op_with_group_merge_pass + add_broadcast_to_elementwise_pass + pd_to_cinn_pass + sub_graph_checker + fully_insert_broadcast_pass + fuse_shape_ops_into_generate_shape_op_pass) endif() if(WITH_PSCORE) From 1e21ed2b0da45ffcc5bf75dd4787bcdaf5f60ac9 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 8 Jan 2024 13:28:40 +0000 Subject: [PATCH 2/4] Fix GenerateShapeOp bug --- .../operator/ir/generate_shape_util.cc | 51 +++---------------- .../dialect/operator/ir/generate_shape_util.h | 12 +---- .../hlir/dialect/operator/ir/op_dialect.cc | 1 + 3 files changed, 10 insertions(+), 54 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index f64bb9269d63a..3d88e575709ea 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -314,51 +314,14 @@ class SubstituteDimExprHelper final { DimExpr4SymbolNameT DimExpr4SymbolName_; }; -std::optional SubstituteDimExpr( +DimExpr SubstituteDimExpr( const DimExpr& dim_expr, const std::function(const std::string& symbol_name)>& DimExpr4SymbolName) { - return SubstituteDimExprHelper(DimExpr4SymbolName).Substitute(dim_expr); -} - -std::function(const std::string& symbol_name)> -MakeGetterDimExpr4SymbolName( - const std::vector>& symbol_bindings, - const std::function( - int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim) { - std::unordered_map>> - symbol_name2in_tensor_dim_pos; - for (const auto& tuple : symbol_bindings) { - const auto& [symbol_name, in_tensor_idx, in_tensor_dim_idx] = tuple; - symbol_name2in_tensor_dim_pos[symbol_name].emplace_back( - std::pair{in_tensor_idx, in_tensor_dim_idx}); - } - return [map = std::move(symbol_name2in_tensor_dim_pos), DimExpr4InputDim]( - const std::string& symbol_name) -> std::optional { - const auto& iter = map.find(symbol_name); - if (iter == map.end()) { - return std::nullopt; - } - const auto& positions = iter->second; - std::optional ret = std::nullopt; - for (const auto& [in_tensor_idx, in_tensor_dim_idx] : positions) { - const auto& current = DimExpr4InputDim(in_tensor_idx, in_tensor_dim_idx); - if (!current.has_value()) { - return std::nullopt; - } - if (ret.has_value()) { - // Same names, same DimExprs. - if (ret.value() != current.value()) { - return std::nullopt; - } - } else { - ret = current; - } - } - return ret; - }; + const auto& opt_substituted = + SubstituteDimExprHelper(DimExpr4SymbolName).Substitute(dim_expr); + if (opt_substituted.has_value()) return opt_substituted.value(); + return dim_expr; } namespace { @@ -596,14 +559,14 @@ void GenerateSymbolBindings( std::vector GetMinimalInputs( const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, const std::vector& input_tensors) { - std::unordered_set handdled_dim_exprs; + std::unordered_set handled_dim_exprs; std::unordered_set first_occurred_input_tensors; auto TryCollectFirstOcurredInput_tensor = [&](pir::Value input_tensor, const std::vector& dim_exprs) { for (const auto& dim_expr : dim_exprs) { if (dim_expr.isa()) continue; - if (!handdled_dim_exprs.insert(dim_expr).second) { + if (handled_dim_exprs.insert(dim_expr).second) { first_occurred_input_tensors.insert(input_tensor); } } diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h index 401c240f61e86..88ed0c2917543 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h @@ -29,19 +29,11 @@ ::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx, std::optional ConvertAttributeToDimExpr( ::pir::Attribute attribute); -std::optional SubstituteDimExpr( +symbol::DimExpr SubstituteDimExpr( const symbol::DimExpr& dim_expr, const std::function( const std::string& symbol_name)>& DimExpr4SymbolName); -std::function(const std::string& symbol_name)> -MakeGetterDimExpr4SymbolName( - const std::vector>& symbol_bindings, - const std::function( - int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim); - std::function(const std::string& symbol_name)> MakeGetterDimExpr4SymbolName( const GenerateShapeOp::SymbolBindings& symbol_bindings, @@ -49,7 +41,7 @@ MakeGetterDimExpr4SymbolName( DimExpr4InputDim); using ShapeOrDataDimExprs4ValueT = - std::function; + std::function; // Returns true if success. bool MakeGenerateShapeOpAttribute( diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc index 5c46fc4be85e5..5de339b3f4306 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc @@ -55,6 +55,7 @@ void OperatorDialect::initialize() { RegisterOp(); RegisterOp(); RegisterOp(); + RegisterOp(); RegisterAttribute(); RegisterAttribute(); } From 4bb268857e73f0da31b161b6b95b704e28702a6f Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 9 Jan 2024 02:17:53 +0000 Subject: [PATCH 3/4] Modify unit test --- test/cpp/pir/cinn/generate_shape_util_test.cc | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/test/cpp/pir/cinn/generate_shape_util_test.cc b/test/cpp/pir/cinn/generate_shape_util_test.cc index 4fc69c877eb5f..2799f34950beb 100644 --- a/test/cpp/pir/cinn/generate_shape_util_test.cc +++ b/test/cpp/pir/cinn/generate_shape_util_test.cc @@ -48,7 +48,7 @@ TEST(DimExprUtil, Convert) { TEST(DimExprUtil, Substitute) { DimExpr dim_expr = CreateExampleDimExpr(); - const auto& opt_expr = SubstituteDimExpr( + const auto& substituted_expr = SubstituteDimExpr( dim_expr, [](const std::string& str) -> std::optional { if (str == "S0") { return DimExpr("symbol0"); @@ -58,9 +58,8 @@ TEST(DimExprUtil, Substitute) { return std::nullopt; } }); - ASSERT_TRUE(opt_expr.has_value()); const auto& ret_expr = SubstituteDimExpr( - opt_expr.value(), [](const std::string& str) -> std::optional { + substituted_expr, [](const std::string& str) -> std::optional { if (str == "symbol0") { return DimExpr("S0"); } else if (str == "symbol1") { @@ -69,26 +68,19 @@ TEST(DimExprUtil, Substitute) { return std::nullopt; } }); - ASSERT_TRUE(ret_expr.has_value()); - ASSERT_EQ(ret_expr.value(), dim_expr); + ASSERT_EQ(ret_expr, dim_expr); } TEST(DimExprUtil, MakeGetterDimExpr4SymbolName) { - std::vector> - symbol_bindings{}; - symbol_bindings.push_back(std::make_tuple("Symbol", 0, 0)); + cinn::dialect::GenerateShapeOp::SymbolBindings symbol_bindings{}; + using ShapeSymbolBinding = cinn::dialect::GenerateShapeOp::ShapeSymbolBinding; + symbol_bindings.emplace_back(ShapeSymbolBinding{"Symbol", 0, 0}); const auto& dim_expr = CreateExampleDimExpr(); + const auto& shape_or_data_dim_exprs = symbol::ShapeOrDataDimExprs({dim_expr}); const auto& DimExpr4SymbolName = MakeGetterDimExpr4SymbolName( symbol_bindings, - [dim_expr](int in_tensor_idx, - int in_tensor_dim_idx) -> std::optional { - if (in_tensor_idx == 0 && in_tensor_dim_idx == 0) { - return dim_expr; - } else { - return std::nullopt; - } + [&](int in_tensor_idx) -> const symbol::ShapeOrDataDimExprs& { + return shape_or_data_dim_exprs; }); const auto& opt_dim_expr = DimExpr4SymbolName("Symbol"); ASSERT_TRUE(opt_dim_expr.has_value()); From 5a2d1bac027a3c2fbb1e9040770e3ce4d904ebdb Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 9 Jan 2024 09:36:01 +0000 Subject: [PATCH 4/4] Fix MakeGetterDimExpr4SymbolName --- .../hlir/dialect/operator/ir/generate_shape_util.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index 3d88e575709ea..85726225284a3 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -350,6 +350,12 @@ std::optional GetDimExprBySymbolBindingImpl( return shape_or_data_dim_expr.shape().at(dim_idx); } +std::string GetSymbolNameBySymbolBinding( + const GenerateShapeOp::SymbolBinding& symbol_binding) { + return std::visit([](const auto& impl) { return impl.symbol_name; }, + symbol_binding); +} + } // namespace std::function(const std::string& symbol_name)> @@ -359,6 +365,10 @@ MakeGetterDimExpr4SymbolName( DimExpr4InputDim) { std::unordered_map> symbol_name2symbol_bindins{}; + for (const auto& symbol_binding : symbol_bindings) { + symbol_name2symbol_bindins[GetSymbolNameBySymbolBinding(symbol_binding)] + .emplace_back(symbol_binding); + } const auto& GetDimExpr = [&](const GenerateShapeOp::SymbolBinding& symbol_binding) { return std::visit(