From b6d78ef63a5bd6f2fc0f316a21325e719cc33874 Mon Sep 17 00:00:00 2001 From: yuanlehome <yuanlehome@163.com> Date: Tue, 7 Nov 2023 05:04:27 +0000 Subject: [PATCH 1/8] fix shadow_output_op --- .../pir/transforms/constant_folding_pass.cc | 3 ++- .../transforms/dead_code_elimination_pass.cc | 8 ++++---- .../pattern_rewrite/pattern_rewrite_test.cc | 19 ++++++++++--------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 4e36f1df9defa8..1f6db5a76fe813 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -57,7 +57,8 @@ class ConstantFoldingPattern : public pir::RewritePattern { // TODO(liuyuanle): Use trait to improve robustness. if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() || op->isa<paddle::dialect::FetchOp>() || - op->isa<paddle::dialect::ShadowOutputOp>()) + op->isa<paddle::dialect::ShadowOutputOp>() || + op->isa<pir::ShadowOutputOp>()) return false; // Inputs must come from get parameter op. 
diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc index 7535ddeb513dbf..0760a9de420b82 100644 --- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc +++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc @@ -35,16 +35,16 @@ class DeadCodeEliminationPattern : public pir::RewritePattern { } bool Match(pir::Operation* op) const override { - if (op->isa<paddle::dialect::FetchOp>() || - op->isa<paddle::dialect::ShadowOutputOp>()) + if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>() || + op->isa<paddle::dialect::ShadowOutputOp>()) { return false; - + } return op->use_empty(); } void Rewrite(pir::Operation* op, pir::PatternRewriter& rewriter) const override { // NOLINT - if (op->dyn_cast<pir::GetParameterOp>()) { + if (op->isa<pir::GetParameterOp>()) { // Delete parameter from program. pir::GetParameterOp get_parameter_op = op->dyn_cast<pir::GetParameterOp>(); diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index d45a74f6fd0d10..28661875b3bafc 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -1126,15 +1126,16 @@ TEST(pattern_rewrite, Patterns) { pm.AddPass(pir::CreateConstantFoldingPass()); pm.AddPass(pir::CreateDeadCodeEliminationPass()); pm.EnablePassTiming(); - pm.EnableIRPrinting(std::make_unique<pir::PassManager::IRPrinterOption>( - [](pir::Pass *pass, pir::Operation *op) { - return pass->name() == "constant_folding_pass"; - }, - [](pir::Pass *pass, pir::Operation *op) { - return pass->name() == "constant_folding_pass"; - }, - true, - true)); + pm.EnableIRPrinting(); + // pm.EnableIRPrinting(std::make_unique<pir::PassManager::IRPrinterOption>( + // [](pir::Pass *pass, pir::Operation *op) { + // return pass->name() == "constant_folding_pass"; + // }, + // [](pir::Pass *pass, pir::Operation *op) { + 
// return pass->name() == "constant_folding_pass"; + // }, + // true, + // true)); CHECK_EQ(pm.Run(&program), true); EXPECT_EQ(program.block()->size(), 2u); From e7cd7beafe3700ed0391035ab986bff846c019de Mon Sep 17 00:00:00 2001 From: yuanlehome <yuanlehome@163.com> Date: Tue, 7 Nov 2023 06:11:19 +0000 Subject: [PATCH 2/8] update --- test/ir/inference/test_inference_predictor_run.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/ir/inference/test_inference_predictor_run.py b/test/ir/inference/test_inference_predictor_run.py index 1c552bc82b77e8..1d8abc174f1cf1 100644 --- a/test/ir/inference/test_inference_predictor_run.py +++ b/test/ir/inference/test_inference_predictor_run.py @@ -62,8 +62,10 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + def enable_pir(self, flag: bool): + paddle.set_flags({'FLAGS_enable_pir_in_executor': flag}) + def init_predictor(self): - paddle.set_flags({'FLAGS_enable_pir_in_executor': True}) config = Config( os.path.join( self.temp_dir.name, @@ -115,12 +117,15 @@ def get_inorder_output(self, predictor): return outputs[0] def test_output(self): + self.enable_pir(False) predictor = self.init_predictor() - inorder_output = self.get_inorder_output(predictor) - disorder_output = self.get_disorder_output(predictor) + output = self.get_inorder_output(predictor) + self.enable_pir(True) + pir_predictor = self.init_predictor() + pir_output = self.get_disorder_output(pir_predictor) np.testing.assert_allclose( - inorder_output.numpy().flatten(), disorder_output.numpy().flatten() + output.numpy().flatten(), pir_output.numpy().flatten() ) From a30a22a72c7cc1fef747c63e829c7dee2d63a981 Mon Sep 17 00:00:00 2001 From: yuanlehome <yuanlehome@163.com> Date: Wed, 8 Nov 2023 07:30:32 +0000 Subject: [PATCH 3/8] rewrite constant_folding_pass --- paddle/fluid/inference/api/CMakeLists.txt | 14 +- .../fluid/inference/api/analysis_predictor.cc | 4 +- .../pir/transforms/constant_folding_pass.cc | 173 
+++++++++--------- .../pir/transforms/constant_folding_pass.h | 10 +- .../transforms/dead_code_elimination_pass.cc | 3 +- .../transforms/transform_general_functions.cc | 8 +- .../transforms/transform_general_functions.h | 6 +- .../pattern_rewrite/pattern_rewrite_test.cc | 18 +- 8 files changed, 122 insertions(+), 114 deletions(-) diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index f3fc183c910b8a..4aadb77aedc0e3 100755 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -64,8 +64,18 @@ if(WIN32) target_link_libraries(paddle_inference_api phi) endif() -set(inference_deps ${analysis_deps} paddle_inference_api analysis - analysis_config naive_executor ${GLOB_PASS_LIB}) +set(PIR_PASS_DEPS + pd_constant_folding_pass dead_code_elimination_pass pd_op_to_kernel_pass + pd_inplace_pass replace_fetch_with_shadow_output_pass) + +set(inference_deps + ${analysis_deps} + paddle_inference_api + analysis + analysis_config + naive_executor + ${GLOB_PASS_LIB} + ${PIR_PASS_DEPS}) if(WITH_GPU AND TENSORRT_FOUND) set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 99b50c9b8ab28c..5f90955ec7cf3e 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -103,6 +103,7 @@ #endif #include "paddle/fluid/ir_adaptor/translator/translate.h" +#include "paddle/fluid/pir/transforms/constant_folding_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/fluid/pir/transforms/inplace_pass.h" #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" @@ -731,10 +732,11 @@ bool AnalysisPredictor::PrepareExecutor() { paddle::TranslateLegacyProgramToProgram(*inference_program_)); ::pir::PassManager pm(::pir::IrContext::Instance(), 2); + 
pm.AddPass(::pir::CreateConstantFoldingPass(place_, sub_scope_)); pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); pm.AddPass(::pir::CreateDeadCodeEliminationPass()); - pm.EnableIRPrinting(); + // pm.EnableIRPrinting(); pm.Run(pir_program_.get()); pir_program_ = std::move( diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 1f6db5a76fe813..66019b3e2c288c 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -17,20 +17,21 @@ #include <memory> #include <string> #include <unordered_map> - -// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in -// paddle/fluid/pir/dialect/CMakeLists.txt. -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include <vector> #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" + +#include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/op_result.h" @@ -46,22 +47,32 @@ namespace { class ConstantFoldingPattern : public pir::RewritePattern { public: - ConstantFoldingPattern(pir::IrContext* context, - paddle::framework::Scope* scope, - pir::PatternBenefit benefit = 1, - const std::vector<std::string>& generated_names = {}) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names), - scope_(scope) {} + ConstantFoldingPattern( + pir::IrContext* context, + size_t* suffix, + 
const phi::Place& place, + paddle::framework::Scope* scope, + paddle::framework::interpreter::ExecutionConfig* exe_config, + std::vector<std::string>* deleted_vars) + : RewritePattern(MatchAnyOpTypeTag(), + 1 /*benefit*/, + context, + {} /*generated_names*/), + counter_(suffix), + place_(place), + scope_(scope), + exe_config_(exe_config), + deleted_vars_(deleted_vars) { + exe_config_->create_local_scope = false; + } bool Match(pir::Operation* op) const override { - // TODO(liuyuanle): Use trait to improve robustness. if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() || - op->isa<paddle::dialect::FetchOp>() || - op->isa<paddle::dialect::ShadowOutputOp>() || - op->isa<pir::ShadowOutputOp>()) + op->isa<pir::ShadowOutputOp>() || op->isa<paddle::dialect::FetchOp>() || + op->isa<paddle::dialect::FeedOp>()) return false; - // Inputs must come from get parameter op. + // inputs must come from get parameter op for (uint32_t i = 0; i < op->num_operands(); ++i) if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>()) return false; @@ -70,73 +81,36 @@ class ConstantFoldingPattern : public pir::RewritePattern { void Rewrite(pir::Operation* op, pir::PatternRewriter& rewriter) const override { // NOLINT - pir::Program* program = op->GetParentProgram(); - auto temp_program = BuildProgramFromOperation(op); - - std::vector<std::string> fetch_var_names; - auto block = temp_program->block(); - for (auto it = block->begin(); it != block->end(); ++it) { - if ((*it)->isa<paddle::dialect::FetchOp>()) { - size_t index = (*it) - ->attributes() - .at("col") - .dyn_cast<pir::Int32Attribute>() - .data(); - - if (fetch_var_names.size() < index + 1) { - fetch_var_names.resize(index + 1); - } - - fetch_var_names[index] = (*it) - ->attributes() - .at("name") - .dyn_cast<pir::StrAttribute>() - .AsString() + - "@fetch"; - } - } + VLOG(4) << "constant_folding_pass applys on [" << op->name() << "] op"; + pir::Program new_program(ir_context()); + auto output_var_name = 
BuildProgramFromOperation(op, &new_program); - // Execute program - exe_config_.create_local_scope = false; + // execute program + exe_config_->skip_gc_vars.insert(output_var_name); auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(temp_program.get()); - paddle::framework::InterpreterCore core(phi::CPUPlace{}, - fetch_var_names, - kernel_program->block(), - scope_, - exe_config_); - - paddle::framework::FetchList fetch_list = core.Run({}); - - // TODO(liuyuanle): Support multiple output. - auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]); - std::unique_ptr<pir::Parameter> parameter = - std::make_unique<pir::Parameter>( - reinterpret_cast<void*>(out_tensor.data()), - out_tensor.numel() * phi::SizeOf(out_tensor.dtype()), - op->result(0).type()); - - std::string param_name = - "@constant_folding_pass@_" + std::to_string(suffix_++); - exe_config_.skip_gc_vars.insert(param_name); - - auto* param_var = scope_->Var(param_name); - auto* param_tensor = param_var->GetMutable<phi::DenseTensor>(); - *param_tensor = out_tensor; - program->SetParameter(param_name, std::move(parameter)); - // rewriter.SetInsertionPoint(op); - auto get_parameter_op = - rewriter.Build<pir::GetParameterOp>(param_name, op->result(0).type()); + paddle::dialect::PdOpLowerToKernelPass(&new_program, place_); + paddle::framework::InterpreterCore core( + place_, {}, kernel_program->block(), scope_, *exe_config_); + + core.Run({}); + // TODO(liuyuanle): support multiple output + auto get_parameter_op = rewriter.Build<pir::GetParameterOp>( + output_var_name, op->result(0).type()); + get_parameter_op->set_attribute( + kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)})); + + VLOG(4) << "constant_folding_pass applied on [" << op->name() << "] op"; rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0)); rewriter.EraseOp(op); } private: - std::unique_ptr<pir::Program> BuildProgramFromOperation( - pir::Operation* op) const { - auto program = 
std::make_unique<pir::Program>(ir_context()); - pir::Builder builder = pir::Builder(ir_context(), program->block()); + std::string BuildProgramFromOperation(pir::Operation* op, + pir::Program* new_program) const { + pir::Builder builder = pir::Builder(ir_context(), new_program->block()); + std::string output_var_name = + "constant_folding@" + std::to_string((*counter_)++); // prepare op inputs std::vector<pir::Value> op_inputs; @@ -147,15 +121,14 @@ class ConstantFoldingPattern : public pir::RewritePattern { phi::errors::InvalidArgument( "Op's input must be a dense tensor type.")); - auto [param_name, param] = - pir::GetParameterFromValue(op->operand_source(i)); - program->SetParameter(param_name, - std::make_unique<pir::Parameter>(*param)); - + const auto& param_name = + pir::GetParameterNameFromValue(op->operand_source(i)); auto* param_var = scope_->FindVar(param_name); PADDLE_ENFORCE_NOT_NULL( param_var, phi::errors::InvalidArgument("Parameter var not in scope.")); + output_var_name = output_var_name + "_" + param_name; + deleted_vars_->push_back(param_name); auto get_parameter_op = builder.Build<pir::GetParameterOp>( param_name, op->operand_source(i).type()); @@ -171,7 +144,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { auto* temp_op = builder.Build(op_inputs, op->attributes(), output_types, op->info()); - // TODO(liuyuanle): Support multiple output. 
+ // TODO(liuyuanle): support multiple output // for (uint32_t i = 0; i < op->num_results(); i++) { PADDLE_ENFORCE_EQ( temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(), @@ -179,52 +152,72 @@ class ConstantFoldingPattern : public pir::RewritePattern { phi::errors::InvalidArgument( "Op's output must be a dense tensor type.")); - builder.Build<paddle::dialect::FetchOp>( - temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0); + builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name); // } - return program; + return output_var_name; } private: + size_t* counter_{nullptr}; + phi::Place place_; paddle::framework::Scope* scope_{nullptr}; - inline static size_t suffix_{0}; - inline static paddle::framework::interpreter::ExecutionConfig exe_config_{}; + paddle::framework::interpreter::ExecutionConfig* exe_config_{nullptr}; + std::vector<std::string>* deleted_vars_{nullptr}; }; class ConstantFoldingPass : public pir::Pass { public: - ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {} + ConstantFoldingPass(const phi::Place& place, paddle::framework::Scope* scope) + : pir::Pass("constant_folding_pass", 1), place_(place), scope_(scope) { + PADDLE_ENFORCE_NOT_NULL( + scope_, phi::errors::InvalidArgument("scope can not be nullptr")); + } + private: bool Initialize(pir::IrContext* context) override { pir::RewritePatternSet ps(context); - ps.Add<ConstantFoldingPattern>(context, &scope_); + ps.Add<ConstantFoldingPattern>( + context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); return true; } void Run(pir::Operation* op) override { + size_t op_nums = op->GetParentProgram()->block()->size(); pir::GreedyRewriteConfig cfg; cfg.use_top_down_traversal = true; cfg.max_iterations = 10; pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + + // delete old parameter var + scope_->EraseVars(deleted_vars_); + LOG(INFO) << " ------ constant_folding_pass done: [" << 
counter_ << "/" + << op_nums << "]"; } bool CanApplyOn(pir::Operation* op) const override { + // TODO(liuyuanle): remove op->isa<::pir::ModuleOp>() return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; } private: + size_t counter_{0}; + phi::Place place_; + paddle::framework::Scope* scope_{nullptr}; + paddle::framework::interpreter::ExecutionConfig exe_config_{}; + std::vector<std::string> deleted_vars_; + pir::FrozenRewritePatternSet patterns_; - paddle::framework::Scope scope_; }; } // namespace namespace pir { -std::unique_ptr<Pass> CreateConstantFoldingPass() { - return std::make_unique<ConstantFoldingPass>(); +std::unique_ptr<Pass> CreateConstantFoldingPass( + const phi::Place& place, paddle::framework::Scope* scope) { + return std::make_unique<ConstantFoldingPass>(place, scope); } } // namespace pir diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.h b/paddle/fluid/pir/transforms/constant_folding_pass.h index b49c9d90493b1c..0939ee589d448d 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.h +++ b/paddle/fluid/pir/transforms/constant_folding_pass.h @@ -15,12 +15,20 @@ #pragma once #include <memory> +#include "paddle/phi/common/place.h" #include "paddle/pir/core/dll_decl.h" +namespace paddle { +namespace framework { +class Scope; +} +} // namespace paddle + namespace pir { class Pass; -IR_API std::unique_ptr<Pass> CreateConstantFoldingPass(); +IR_API std::unique_ptr<Pass> CreateConstantFoldingPass( + const phi::Place& place, paddle::framework::Scope* scope); } // namespace pir diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc index 0760a9de420b82..bfde883cac67a9 100644 --- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc +++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc @@ -35,8 +35,7 @@ class DeadCodeEliminationPattern : public pir::RewritePattern { } bool Match(pir::Operation* op) const override { - if 
(op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>() || - op->isa<paddle::dialect::ShadowOutputOp>()) { + if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>()) { return false; } return op->use_empty(); diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc index 8bd5028688b133..1d7c226197668f 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.cc +++ b/paddle/fluid/pir/transforms/transform_general_functions.cc @@ -22,8 +22,7 @@ namespace pir { -std::pair<std::string, pir::Parameter*> GetParameterFromValue( - pir::Value value) { +std::string GetParameterNameFromValue(pir::Value value) { pir::GetParameterOp op = value.dyn_cast<OpResult>().owner()->dyn_cast<pir::GetParameterOp>(); PADDLE_ENFORCE_NOT_NULL( @@ -37,10 +36,7 @@ std::pair<std::string, pir::Parameter*> GetParameterFromValue( .at(op.attributes_name[0]) .dyn_cast<pir::StrAttribute>() .AsString(); - pir::Parameter* param = program->GetParameter(name); - PADDLE_ENFORCE_NOT_NULL( - param, phi::errors::InvalidArgument("Parameter should not be null.")); - return {name, param}; + return name; } const phi::DDim& GetShapeFromValue(pir::Value value) { diff --git a/paddle/fluid/pir/transforms/transform_general_functions.h b/paddle/fluid/pir/transforms/transform_general_functions.h index 77c790235b8329..0d35ff776ce8c1 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.h +++ b/paddle/fluid/pir/transforms/transform_general_functions.h @@ -25,16 +25,16 @@ namespace pir { /** - * @brief Get the [name, parameter] pair of pararmeter from a value. + * @brief Get the name of parameter from a value. * * @note The value must be a output of a GetParameterOp.
* * @param pir::Value * - * @return std::pair<std::string, pir::Parameter*> + * @return std::string */ -std::pair<std::string, pir::Parameter*> GetParameterFromValue(pir::Value value); +std::string GetParameterNameFromValue(pir::Value value); /** * @brief Get tensor's shape from a value. diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index 28661875b3bafc..df112a6a3f44c3 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -20,11 +20,15 @@ #include <sstream> #include <vector> +#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/constant_folding_pass.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" -#include "paddle/phi/core/kernel_registry.h" + #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_dialect.h" @@ -44,13 +48,9 @@ #include "paddle/pir/pattern_rewrite/pattern_match.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" -// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in -// paddle/fluid/pir/dialect/CMakeLists.txt. 
-#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" - -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/phi/common/place.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/kernel_registry.h" // build Conv2dFusionOp #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" @@ -1120,10 +1120,10 @@ TEST(pattern_rewrite, Patterns) { BuildProgram(builder); EXPECT_EQ(program.block()->size(), 11u); - + paddle::framework::Scope scope; pir::PassManager pm(ctx); pm.AddPass(std::make_unique<TestPass>()); - pm.AddPass(pir::CreateConstantFoldingPass()); + pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope)); pm.AddPass(pir::CreateDeadCodeEliminationPass()); pm.EnablePassTiming(); pm.EnableIRPrinting(); From 1185d56e67de7c16246737afb3b532a62b334f2f Mon Sep 17 00:00:00 2001 From: yuanlehome <yuanlehome@163.com> Date: Wed, 8 Nov 2023 08:00:34 +0000 Subject: [PATCH 4/8] update --- paddle/fluid/pir/transforms/constant_folding_pass.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 66019b3e2c288c..5934294769752e 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -109,8 +109,6 @@ class ConstantFoldingPattern : public pir::RewritePattern { std::string BuildProgramFromOperation(pir::Operation* op, pir::Program* new_program) const { pir::Builder builder = pir::Builder(ir_context(), new_program->block()); - std::string output_var_name = - "constant_folding@" + std::to_string((*counter_)++); // prepare op inputs std::vector<pir::Value> op_inputs; @@ -127,7 +125,6 @@ class ConstantFoldingPattern : public pir::RewritePattern { PADDLE_ENFORCE_NOT_NULL( param_var, phi::errors::InvalidArgument("Parameter var not in scope.")); - output_var_name = 
output_var_name + "_" + param_name; deleted_vars_->push_back(param_name); auto get_parameter_op = builder.Build<pir::GetParameterOp>( @@ -152,6 +149,10 @@ class ConstantFoldingPattern : public pir::RewritePattern { phi::errors::InvalidArgument( "Op's output must be a dense tensor type.")); + std::stringstream ss; + ss << std::chrono::high_resolution_clock::now().time_since_epoch().count(); + std::string output_var_name = ss.str() + std::to_string((*counter_)++); + builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name); // } From 700ebc71e1e7b1641f5038c52a763a4e20cf11a6 Mon Sep 17 00:00:00 2001 From: yuanlehome <yuanlehome@163.com> Date: Wed, 8 Nov 2023 08:07:05 +0000 Subject: [PATCH 5/8] update --- paddle/fluid/pir/transforms/constant_folding_pass.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 5934294769752e..fb3d7de7331f38 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -151,7 +151,8 @@ class ConstantFoldingPattern : public pir::RewritePattern { std::stringstream ss; ss << std::chrono::high_resolution_clock::now().time_since_epoch().count(); - std::string output_var_name = ss.str() + std::to_string((*counter_)++); + std::string output_var_name = + "constant_folding@_" + ss.str() + std::to_string((*counter_)++); builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name); // } From a6900b3207edeee67b2c9f8a806426b1af13e812 Mon Sep 17 00:00:00 2001 From: yuanlehome <yuanlehome@163.com> Date: Wed, 8 Nov 2023 08:53:56 +0000 Subject: [PATCH 6/8] fix compile --- paddle/fluid/inference/api/CMakeLists.txt | 6 +---- .../pir/transforms/constant_folding_pass.cc | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt 
index 4aadb77aedc0e3..f15bd26c4476a6 100755 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -64,10 +64,6 @@ if(WIN32) target_link_libraries(paddle_inference_api phi) endif() -set(PIR_PASS_DEPS - pd_constant_folding_pass dead_code_elimination_pass pd_op_to_kernel_pass - pd_inplace_pass replace_fetch_with_shadow_output_pass) - set(inference_deps ${analysis_deps} paddle_inference_api @@ -75,7 +71,7 @@ set(inference_deps analysis_config naive_executor ${GLOB_PASS_LIB} - ${PIR_PASS_DEPS}) + transform) if(WITH_GPU AND TENSORRT_FOUND) set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index fb3d7de7331f38..a89caad7d65ff1 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -113,11 +113,11 @@ class ConstantFoldingPattern : public pir::RewritePattern { // prepare op inputs std::vector<pir::Value> op_inputs; for (uint32_t i = 0; i < op->num_operands(); i++) { - PADDLE_ENFORCE_EQ( - op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(), - true, - phi::errors::InvalidArgument( - "Op's input must be a dense tensor type.")); + // PADDLE_ENFORCE_EQ( + // op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(), + // true, + // phi::errors::InvalidArgument( + // "Op [%s] input must be a dense tensor type.", op->name())); const auto& param_name = pir::GetParameterNameFromValue(op->operand_source(i)); @@ -125,7 +125,9 @@ class ConstantFoldingPattern : public pir::RewritePattern { PADDLE_ENFORCE_NOT_NULL( param_var, phi::errors::InvalidArgument("Parameter var not in scope.")); - deleted_vars_->push_back(param_name); + if (op->operand_source(i).use_count() == 1) { + deleted_vars_->push_back(param_name); + } auto get_parameter_op = builder.Build<pir::GetParameterOp>( param_name, 
op->operand_source(i).type()); @@ -143,11 +145,11 @@ class ConstantFoldingPattern : public pir::RewritePattern { // TODO(liuyuanle): support multiple output // for (uint32_t i = 0; i < op->num_results(); i++) { - PADDLE_ENFORCE_EQ( - temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(), - true, - phi::errors::InvalidArgument( - "Op's output must be a dense tensor type.")); + // PADDLE_ENFORCE_EQ( + // temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(), + // true, + // phi::errors::InvalidArgument( + // "Op [%s] output must be a dense tensor type.", temp_op->name())); std::stringstream ss; ss << std::chrono::high_resolution_clock::now().time_since_epoch().count(); From 400762de23d89c20380799590c9d05df21a5ffaa Mon Sep 17 00:00:00 2001 From: yuanlehome <yuanlehome@163.com> Date: Wed, 8 Nov 2023 11:14:27 +0000 Subject: [PATCH 7/8] fix dce --- paddle/fluid/pir/transforms/dead_code_elimination_pass.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc index bfde883cac67a9..9c6fcd9b3d9ca4 100644 --- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc +++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" @@ -35,7 +36,8 @@ class DeadCodeEliminationPattern : public pir::RewritePattern { } bool Match(pir::Operation* op) const override { - if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>()) { + if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>() || + op->isa<pir::YieldOp>()) { return false; } return 
op->use_empty(); From b4e4a1c5a7b8d8425e36adfbfc644eeb7927a099 Mon Sep 17 00:00:00 2001 From: yuanlehome <yuanlehome@163.com> Date: Wed, 8 Nov 2023 12:04:42 +0000 Subject: [PATCH 8/8] enhance judgement --- .../pir/transforms/constant_folding_pass.cc | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index a89caad7d65ff1..39daebc1a3b8f9 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -72,10 +72,10 @@ class ConstantFoldingPattern : public pir::RewritePattern { op->isa<paddle::dialect::FeedOp>()) return false; - // inputs must come from get parameter op - for (uint32_t i = 0; i < op->num_operands(); ++i) - if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>()) - return false; + if (!ValidOp(op)) { + return false; + } + return true; } @@ -106,6 +106,26 @@ class ConstantFoldingPattern : public pir::RewritePattern { } private: + bool ValidOp(pir::Operation* op) const { + for (uint32_t i = 0; i < op->num_operands(); i++) { + // 1. inputs must come from get_parameter op + // 2. inputs must be a dense tensor type + if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>() || + !op->operand_source(i) + .type() + .isa<paddle::dialect::DenseTensorType>()) { + return false; + } + // 3. 
outputs must be a dense tensor type + for (uint32_t i = 0; i < op->num_results(); i++) { + if (!op->result(i).type().isa<paddle::dialect::DenseTensorType>()) { + return false; + } + } + } + return true; + } + std::string BuildProgramFromOperation(pir::Operation* op, pir::Program* new_program) const { pir::Builder builder = pir::Builder(ir_context(), new_program->block()); @@ -113,12 +133,6 @@ class ConstantFoldingPattern : public pir::RewritePattern { // prepare op inputs std::vector<pir::Value> op_inputs; for (uint32_t i = 0; i < op->num_operands(); i++) { - // PADDLE_ENFORCE_EQ( - // op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(), - // true, - // phi::errors::InvalidArgument( - // "Op [%s] input must be a dense tensor type.", op->name())); - const auto& param_name = pir::GetParameterNameFromValue(op->operand_source(i)); auto* param_var = scope_->FindVar(param_name); @@ -145,11 +159,6 @@ class ConstantFoldingPattern : public pir::RewritePattern { // TODO(liuyuanle): support multiple output // for (uint32_t i = 0; i < op->num_results(); i++) { - // PADDLE_ENFORCE_EQ( - // temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(), - // true, - // phi::errors::InvalidArgument( - // "Op [%s] output must be a dense tensor type.", temp_op->name())); std::stringstream ss; ss << std::chrono::high_resolution_clock::now().time_since_epoch().count();