From adcc73c7923162957ef7afdf6b106dc50dc0ffb0 Mon Sep 17 00:00:00 2001 From: gongshaotian <> Date: Fri, 25 Aug 2023 06:22:57 +0000 Subject: [PATCH 1/3] Add fusion testing code for fullOp and expandOp --- paddle/fluid/ir/drr/ir_operation_creator.cc | 14 +++++ paddle/fluid/ir/drr/match_context_impl.h | 23 ++++++++ test/cpp/ir/pattern_rewrite/drr_test.cc | 64 +++++++++++++-------- 3 files changed, 77 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/ir/drr/ir_operation_creator.cc b/paddle/fluid/ir/drr/ir_operation_creator.cc index 729a47cc3a691e..44dd90e09ed427 100644 --- a/paddle/fluid/ir/drr/ir_operation_creator.cc +++ b/paddle/fluid/ir/drr/ir_operation_creator.cc @@ -70,6 +70,7 @@ Operation* CreateOperation(const OpCall& op_call, op_call.outputs()[1]->name(), std::make_shared(reshape_op->result(1))); return reshape_op; + } else if (op_call.name() == "pd.transpose") { const auto& inputs = op_call.inputs(); std::vector ir_values = @@ -81,6 +82,7 @@ Operation* CreateOperation(const OpCall& op_call, op_call.outputs()[0]->name(), std::make_shared(transpose_op->result(0))); return transpose_op; + } else if (op_call.name() == "pd.cast") { const auto& inputs = op_call.inputs(); std::vector ir_values = @@ -91,6 +93,18 @@ Operation* CreateOperation(const OpCall& op_call, res_match_ctx->BindIrValue(op_call.outputs()[0]->name(), std::make_shared(cast_op->result(0))); return cast_op; + + } else if (op_call.name() == "pd.full") { + const auto& inputs = op_call.inputs(); + std::vector ir_values = + GetIrValuesByDrrTensors(inputs, *res_match_ctx); + Operation* full_op = rewriter.Build( + CreateAttributeMap(op_call, src_match_ctx) + ); + res_match_ctx->BindIrValue( + op_call.outputs()[0]->name(), + std::make_shared(full_op->result(0))); + return full_op; } LOG(ERROR) << "Unknown op " << op_call.name(); diff --git a/paddle/fluid/ir/drr/match_context_impl.h b/paddle/fluid/ir/drr/match_context_impl.h index 6a184e6d45527c..ef523101eed868 100644 --- a/paddle/fluid/ir/drr/match_context_impl.h +++ b/paddle/fluid/ir/drr/match_context_impl.h @@ -42,6 +42,7 @@ PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute); PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType, paddle::dialect::DataTypeAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); template struct IrAttrTypeCast { @@ -62,6 +63,28 @@ struct IrAttrTypeCast> { } }; +template <> +struct IrAttrTypeCast>{ + static std::vector To(const ir::Attribute& attr){ + std::vector result; + auto array_attr = attr.dyn_cast(); + + if (array_attr) { + for(size_t i = 0; i < array_attr.size(); i++){ + result.push_back(array_attr.at(i).dyn_cast().data()); + } + return result; + } + + else if (attr.dyn_cast()) { + result = attr.dyn_cast().data().GetData(); + return result; + } + LOG(ERROR) << "Dynamic cast failed for IR attribute vector"; + IR_THROW(); + } +}; + class MatchContextImpl final { public: MatchContextImpl() = default; diff --git a/test/cpp/ir/pattern_rewrite/drr_test.cc b/test/cpp/ir/pattern_rewrite/drr_test.cc index b3803f93df0a20..d9a30f342ac06d 100644 --- a/test/cpp/ir/pattern_rewrite/drr_test.cc +++ b/test/cpp/ir/pattern_rewrite/drr_test.cc @@ -45,31 +45,37 @@ class RemoveRedundentReshapePattern } }; -class FoldBroadcastToConstantPattern - : public ir::drr::DrrPatternBase { +class RemoveExpandPattern + : public ir::drr::DrrPatternBase { public: void operator()(ir::drr::DrrPatternContext *ctx) const override { + // Source Pattern 中可匹配的类型包括 Op 和 Tensor ir::drr::SourcePattern pat = ctx->SourcePattern(); - // Source Pattern 中可匹配的类型包括 Op 和 Tensor - const auto &fill_constant = pat.Op( - "pd.fill_constant", - {{"value", pat.Attr("value_1")}, {"dtype", pat.Attr("dtype_1")}}); - const auto &broadcast_to = - pat.Op("pd.expand", {{"shape", pat.Attr("shape_1")}}); - // 匹配fill_constant+broadcast_to,同时对输出张量标记为ret,方便后面加约束 - pat.Tensor("ret") = broadcast_to(fill_constant()); - // Constrains:本Pass无额外的约束规则 - // Result patterns:要替换为的子图 + const auto &full1 = pat.Op( + "pd.full", + {{"shape", pat.Attr("shape_1")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}} + ); + const auto &full_int_array1 = pat.Op( + "pd.full_int_array", + {{"value", pat.Attr("expand_shape_value")}, + {"dtype", pat.Attr("dtype_2")}, + {"place", pat.Attr("place_2")}} + ); + const auto &expand = pat.Op("pd.expand"); + pat.Tensor("ret") = expand(full1(), full_int_array1()); + + // Result patterns:要替换为的子图. Constrains: 本Pass无额外约束规则 ir::drr::ResultPattern res = pat.ResultPattern(); - // 使用 folded_fill_constant 替换 - // broadcast_to(fill_constant()),注意shape属性已更新 所有 ret - // 参数均在Source Pattern中使用,对 ret 的赋值等同于对 ret 的 producer - // op的删除和重连接 - const auto &folded_fill_constant = res.Op("pd.fill_constant", - {{"shape", res.Attr("shape_1")}, - {"value", res.Attr("value_1")}, - {"dtype", res.Attr("dtype_1")}}); - res.Tensor("ret") = folded_fill_constant(); + const auto &full2 = res.Op("pd.full", + {{"shape", pat.Attr("expand_shape_value")}, // full_int_array的value为expand的shape + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}} + ); + res.Tensor("ret") = full2(); } }; @@ -124,14 +130,23 @@ class RemoveUselessCastPattern void BuildProgram(ir::Builder &builder) { // NOLINT paddle::dialect::FullOp full_input_op = - builder.Build(std::vector{4, 3, 16, 16}, + builder.Build(std::vector{4, 3, 16}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + paddle::dialect::FullIntArrayOp full_int_array_op = + builder.Build(std::vector{4, 3, 16, 16}, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::ExpandOp expand_op = + builder.Build( + full_input_op.out(), full_int_array_op.out()); + paddle::dialect::ReshapeOp reshape_op1 = builder.Build( - full_input_op.out(), std::vector{16, 3, 4, 16}); + expand_op.out(), std::vector{16, 3, 4, 16}); paddle::dialect::ReshapeOp reshape_op2 = builder.Build( @@ -170,6 +185,7 @@ class DrrPatternRewritePass : public ir::Pass { ps.Add(RemoveRedundentTransposePattern().Build(context)); ps.Add(RemoveRedundentCastPattern().Build(context)); ps.Add(RemoveUselessCastPattern().Build(context)); + ps.Add(RemoveExpandPattern().Build(context)); patterns_ = ir::FrozenRewritePatternSet(std::move(ps)); return true; @@ -197,7 +213,7 @@ TEST(DrrTest, drr) { ir::Builder builder = ir::Builder(ctx, program.block()); BuildProgram(builder); - EXPECT_EQ(program.block()->size(), 12u); + EXPECT_EQ(program.block()->size(), 14u); ir::PassManager pm(ctx); pm.AddPass(std::make_unique()); From 35c9390343e221aa657ea993c7ef350f0b7f56bd Mon Sep 17 00:00:00 2001 From: gongshaotian Date: Fri, 25 Aug 2023 11:17:58 +0000 Subject: [PATCH 2/3] Standardize code format --- paddle/fluid/ir/drr/match_context_impl.h | 24 +++++----- test/cpp/ir/pattern_rewrite/drr_test.cc | 57 +++++++++++------------- 2 files changed, 38 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/ir/drr/match_context_impl.h b/paddle/fluid/ir/drr/match_context_impl.h index ef523101eed868..2e6528d9aa50bf 100644 --- a/paddle/fluid/ir/drr/match_context_impl.h +++ b/paddle/fluid/ir/drr/match_context_impl.h @@ -64,24 +64,22 @@ struct IrAttrTypeCast> { }; template <> -struct IrAttrTypeCast>{ - static std::vector To(const ir::Attribute& attr){ +struct IrAttrTypeCast> { + static std::vector To(const ir::Attribute& attr) { std::vector result; - auto array_attr = attr.dyn_cast(); - - if (array_attr) { - for(size_t i = 0; i < array_attr.size(); i++){ - result.push_back(array_attr.at(i).dyn_cast().data()); + if (attr.dyn_cast()) { + auto array_attr = attr.dyn_cast(); + for (size_t i = 0; i < array_attr.size(); i++) { + result.push_back( + array_attr.at(i).dyn_cast().data()); } return result; - } - - else if (attr.dyn_cast()) { - result = attr.dyn_cast().data().GetData(); + } else if (attr.dyn_cast()) { + result = + attr.dyn_cast().data().GetData(); return result; } - LOG(ERROR) << "Dynamic cast failed for IR attribute vector"; - IR_THROW(); + IR_THROW("Dynamic cast failed for IR attribute vector"); } }; diff --git a/test/cpp/ir/pattern_rewrite/drr_test.cc b/test/cpp/ir/pattern_rewrite/drr_test.cc index d9a30f342ac06d..d1adca21f8435b 100644 --- a/test/cpp/ir/pattern_rewrite/drr_test.cc +++ b/test/cpp/ir/pattern_rewrite/drr_test.cc @@ -45,36 +45,32 @@ class RemoveRedundentReshapePattern } }; -class RemoveExpandPattern - : public ir::drr::DrrPatternBase { +class FoldExpandToConstantPattern + : public ir::drr::DrrPatternBase { public: void operator()(ir::drr::DrrPatternContext *ctx) const override { - // Source Pattern 中可匹配的类型包括 Op 和 Tensor + // Source Pattern 中可匹配的类型包括 Op 和 Tensor ir::drr::SourcePattern pat = ctx->SourcePattern(); - const auto &full1 = pat.Op( - "pd.full", - {{"shape", pat.Attr("shape_1")}, - {"value", pat.Attr("value_1")}, - {"dtype", pat.Attr("dtype_1")}, - {"place", pat.Attr("place_1")}} - ); - const auto &full_int_array1 = pat.Op( - "pd.full_int_array", - {{"value", pat.Attr("expand_shape_value")}, - {"dtype", pat.Attr("dtype_2")}, - {"place", pat.Attr("place_2")}} - ); + const auto &full1 = pat.Op("pd.full", + {{"shape", pat.Attr("shape_1")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + const auto &full_int_array1 = + pat.Op("pd.full_int_array", + {{"value", pat.Attr("expand_shape_value")}, + {"dtype", pat.Attr("dtype_2")}, + {"place", pat.Attr("place_2")}}); const auto &expand = pat.Op("pd.expand"); pat.Tensor("ret") = expand(full1(), full_int_array1()); - + // Result patterns:要替换为的子图. Constrains: 本Pass无额外约束规则 ir::drr::ResultPattern res = pat.ResultPattern(); const auto &full2 = res.Op("pd.full", - {{"shape", pat.Attr("expand_shape_value")}, // full_int_array的value为expand的shape - {"value", pat.Attr("value_1")}, - {"dtype", pat.Attr("dtype_1")}, - {"place", pat.Attr("place_1")}} - ); + {{"shape", pat.Attr("expand_shape_value")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); res.Tensor("ret") = full2(); } }; @@ -135,14 +131,15 @@ void BuildProgram(ir::Builder &builder) { // NOLINT phi::DataType::FLOAT32, phi::CPUPlace()); - paddle::dialect::FullIntArrayOp full_int_array_op = - builder.Build(std::vector{4, 3, 16, 16}, - phi::DataType::FLOAT32, - phi::CPUPlace()); + paddle::dialect::FullIntArrayOp full_int_array_op = + builder.Build( + std::vector{4, 3, 16, 16}, + phi::DataType::FLOAT32, + phi::CPUPlace()); - paddle::dialect::ExpandOp expand_op = - builder.Build( - full_input_op.out(), full_int_array_op.out()); + paddle::dialect::ExpandOp expand_op = + builder.Build(full_input_op.out(), + full_int_array_op.out()); paddle::dialect::ReshapeOp reshape_op1 = builder.Build( @@ -185,7 +182,7 @@ class DrrPatternRewritePass : public ir::Pass { ps.Add(RemoveRedundentTransposePattern().Build(context)); ps.Add(RemoveRedundentCastPattern().Build(context)); ps.Add(RemoveUselessCastPattern().Build(context)); - ps.Add(RemoveExpandPattern().Build(context)); + ps.Add(FoldExpandToConstantPattern().Build(context)); patterns_ = ir::FrozenRewritePatternSet(std::move(ps)); return true; From da03ecbfa1555b7ce69903de913c5c9fdc19093f Mon Sep 17 00:00:00 2001 From: gongshaotian Date: Fri, 25 Aug 2023 11:34:22 +0000 Subject: [PATCH 3/3] Replace IR_THROW() with PADDLE_THROW() --- paddle/fluid/ir/drr/ir_operation_creator.cc | 15 ++++++--------- paddle/fluid/ir/drr/match_context_impl.h | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/ir/drr/ir_operation_creator.cc b/paddle/fluid/ir/drr/ir_operation_creator.cc index 44dd90e09ed427..170db184e164c4 100644 --- a/paddle/fluid/ir/drr/ir_operation_creator.cc +++ b/paddle/fluid/ir/drr/ir_operation_creator.cc @@ -70,7 +70,7 @@ Operation* CreateOperation(const OpCall& op_call, op_call.outputs()[1]->name(), std::make_shared(reshape_op->result(1))); return reshape_op; - + } else if (op_call.name() == "pd.transpose") { const auto& inputs = op_call.inputs(); std::vector ir_values = @@ -96,19 +96,16 @@ Operation* CreateOperation(const OpCall& op_call, } else if (op_call.name() == "pd.full") { const auto& inputs = op_call.inputs(); - std::vector ir_values = + std::vector ir_values = GetIrValuesByDrrTensors(inputs, *res_match_ctx); Operation* full_op = rewriter.Build( - CreateAttributeMap(op_call, src_match_ctx) - ); - res_match_ctx->BindIrValue( - op_call.outputs()[0]->name(), - std::make_shared(full_op->result(0))); + CreateAttributeMap(op_call, src_match_ctx)); + res_match_ctx->BindIrValue(op_call.outputs()[0]->name(), + std::make_shared(full_op->result(0))); return full_op; } - LOG(ERROR) << "Unknown op " << op_call.name(); - return nullptr; + PADDLE_THROW("Unknown op :" + op_call.name()); } } // namespace drr diff --git a/paddle/fluid/ir/drr/match_context_impl.h b/paddle/fluid/ir/drr/match_context_impl.h index 2e6528d9aa50bf..a7afad599281b9 100644 --- a/paddle/fluid/ir/drr/match_context_impl.h +++ b/paddle/fluid/ir/drr/match_context_impl.h @@ -79,7 +79,7 @@ struct IrAttrTypeCast> { attr.dyn_cast().data().GetData(); return result; } - IR_THROW("Dynamic cast failed for IR attribute vector"); + PADDLE_THROW("Dynamic cast failed for IR attribute vector"); } };