From b6d78ef63a5bd6f2fc0f316a21325e719cc33874 Mon Sep 17 00:00:00 2001
From: yuanlehome <yuanlehome@163.com>
Date: Tue, 7 Nov 2023 05:04:27 +0000
Subject: [PATCH 1/8] fix shadow_output_op

---
 .../pir/transforms/constant_folding_pass.cc   |  3 ++-
 .../transforms/dead_code_elimination_pass.cc  |  8 ++++----
 .../pattern_rewrite/pattern_rewrite_test.cc   | 19 ++++++++++---------
 3 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 4e36f1df9defa8..1f6db5a76fe813 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -57,7 +57,8 @@ class ConstantFoldingPattern : public pir::RewritePattern {
     // TODO(liuyuanle): Use trait to improve robustness.
     if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() ||
         op->isa<paddle::dialect::FetchOp>() ||
-        op->isa<paddle::dialect::ShadowOutputOp>())
+        op->isa<paddle::dialect::ShadowOutputOp>() ||
+        op->isa<pir::ShadowOutputOp>())
       return false;
 
     // Inputs must come from get parameter op.
diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
index 7535ddeb513dbf..0760a9de420b82 100644
--- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
+++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
@@ -35,16 +35,16 @@ class DeadCodeEliminationPattern : public pir::RewritePattern {
   }
 
   bool Match(pir::Operation* op) const override {
-    if (op->isa<paddle::dialect::FetchOp>() ||
-        op->isa<paddle::dialect::ShadowOutputOp>())
+    if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>() ||
+        op->isa<paddle::dialect::ShadowOutputOp>()) {
       return false;
-
+    }
     return op->use_empty();
   }
 
   void Rewrite(pir::Operation* op,
                pir::PatternRewriter& rewriter) const override {  // NOLINT
-    if (op->dyn_cast<pir::GetParameterOp>()) {
+    if (op->isa<pir::GetParameterOp>()) {
       // Delete parameter from program.
       pir::GetParameterOp get_parameter_op =
           op->dyn_cast<pir::GetParameterOp>();
diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
index d45a74f6fd0d10..28661875b3bafc 100644
--- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
@@ -1126,15 +1126,16 @@ TEST(pattern_rewrite, Patterns) {
   pm.AddPass(pir::CreateConstantFoldingPass());
   pm.AddPass(pir::CreateDeadCodeEliminationPass());
   pm.EnablePassTiming();
-  pm.EnableIRPrinting(std::make_unique<pir::PassManager::IRPrinterOption>(
-      [](pir::Pass *pass, pir::Operation *op) {
-        return pass->name() == "constant_folding_pass";
-      },
-      [](pir::Pass *pass, pir::Operation *op) {
-        return pass->name() == "constant_folding_pass";
-      },
-      true,
-      true));
+  pm.EnableIRPrinting();
+  //   pm.EnableIRPrinting(std::make_unique<pir::PassManager::IRPrinterOption>(
+  //       [](pir::Pass *pass, pir::Operation *op) {
+  //         return pass->name() == "constant_folding_pass";
+  //       },
+  //       [](pir::Pass *pass, pir::Operation *op) {
+  //         return pass->name() == "constant_folding_pass";
+  //       },
+  //       true,
+  //       true));
 
   CHECK_EQ(pm.Run(&program), true);
   EXPECT_EQ(program.block()->size(), 2u);

From e7cd7beafe3700ed0391035ab986bff846c019de Mon Sep 17 00:00:00 2001
From: yuanlehome <yuanlehome@163.com>
Date: Tue, 7 Nov 2023 06:11:19 +0000
Subject: [PATCH 2/8] update

---
 test/ir/inference/test_inference_predictor_run.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/test/ir/inference/test_inference_predictor_run.py b/test/ir/inference/test_inference_predictor_run.py
index 1c552bc82b77e8..1d8abc174f1cf1 100644
--- a/test/ir/inference/test_inference_predictor_run.py
+++ b/test/ir/inference/test_inference_predictor_run.py
@@ -62,8 +62,10 @@ def setUp(self):
     def tearDown(self):
         self.temp_dir.cleanup()
 
+    def enable_pir(self, flag: bool):
+        paddle.set_flags({'FLAGS_enable_pir_in_executor': flag})
+
     def init_predictor(self):
-        paddle.set_flags({'FLAGS_enable_pir_in_executor': True})
         config = Config(
             os.path.join(
                 self.temp_dir.name,
@@ -115,12 +117,15 @@ def get_inorder_output(self, predictor):
         return outputs[0]
 
     def test_output(self):
+        self.enable_pir(False)
         predictor = self.init_predictor()
-        inorder_output = self.get_inorder_output(predictor)
-        disorder_output = self.get_disorder_output(predictor)
+        output = self.get_inorder_output(predictor)
+        self.enable_pir(True)
+        pir_predictor = self.init_predictor()
+        pir_output = self.get_disorder_output(pir_predictor)
 
         np.testing.assert_allclose(
-            inorder_output.numpy().flatten(), disorder_output.numpy().flatten()
+            output.numpy().flatten(), pir_output.numpy().flatten()
         )
 
 

From a30a22a72c7cc1fef747c63e829c7dee2d63a981 Mon Sep 17 00:00:00 2001
From: yuanlehome <yuanlehome@163.com>
Date: Wed, 8 Nov 2023 07:30:32 +0000
Subject: [PATCH 3/8] rewrite constant_folding_pass

---
 paddle/fluid/inference/api/CMakeLists.txt     |  14 +-
 .../fluid/inference/api/analysis_predictor.cc |   4 +-
 .../pir/transforms/constant_folding_pass.cc   | 173 +++++++++---------
 .../pir/transforms/constant_folding_pass.h    |  10 +-
 .../transforms/dead_code_elimination_pass.cc  |   3 +-
 .../transforms/transform_general_functions.cc |   8 +-
 .../transforms/transform_general_functions.h  |   6 +-
 .../pattern_rewrite/pattern_rewrite_test.cc   |  18 +-
 8 files changed, 122 insertions(+), 114 deletions(-)

diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt
index f3fc183c910b8a..4aadb77aedc0e3 100755
--- a/paddle/fluid/inference/api/CMakeLists.txt
+++ b/paddle/fluid/inference/api/CMakeLists.txt
@@ -64,8 +64,18 @@ if(WIN32)
   target_link_libraries(paddle_inference_api phi)
 endif()
 
-set(inference_deps ${analysis_deps} paddle_inference_api analysis
-                   analysis_config naive_executor ${GLOB_PASS_LIB})
+set(PIR_PASS_DEPS
+    pd_constant_folding_pass dead_code_elimination_pass pd_op_to_kernel_pass
+    pd_inplace_pass replace_fetch_with_shadow_output_pass)
+
+set(inference_deps
+    ${analysis_deps}
+    paddle_inference_api
+    analysis
+    analysis_config
+    naive_executor
+    ${GLOB_PASS_LIB}
+    ${PIR_PASS_DEPS})
 
 if(WITH_GPU AND TENSORRT_FOUND)
   set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter)
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 99b50c9b8ab28c..5f90955ec7cf3e 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -103,6 +103,7 @@
 #endif
 
 #include "paddle/fluid/ir_adaptor/translator/translate.h"
+#include "paddle/fluid/pir/transforms/constant_folding_pass.h"
 #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/inplace_pass.h"
 #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
@@ -731,10 +732,11 @@ bool AnalysisPredictor::PrepareExecutor() {
           paddle::TranslateLegacyProgramToProgram(*inference_program_));
 
       ::pir::PassManager pm(::pir::IrContext::Instance(), 2);
+      pm.AddPass(::pir::CreateConstantFoldingPass(place_, sub_scope_));
       pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
       pm.AddPass(::pir::CreateDeadCodeEliminationPass());
 
-      pm.EnableIRPrinting();
+      // pm.EnableIRPrinting();
       pm.Run(pir_program_.get());
 
       pir_program_ = std::move(
diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 1f6db5a76fe813..66019b3e2c288c 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -17,20 +17,21 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
-
-// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
-// paddle/fluid/pir/dialect/CMakeLists.txt.
-#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include <vector>
 
 #include "paddle/fluid/framework/new_executor/interpretercore.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
 #include "paddle/fluid/pir/transforms/transform_general_functions.h"
+
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
+
+#include "paddle/pir/core/builtin_attribute.h"
 #include "paddle/pir/core/builtin_op.h"
 #include "paddle/pir/core/ir_context.h"
 #include "paddle/pir/core/op_result.h"
@@ -46,22 +47,32 @@ namespace {
 
 class ConstantFoldingPattern : public pir::RewritePattern {
  public:
-  ConstantFoldingPattern(pir::IrContext* context,
-                         paddle::framework::Scope* scope,
-                         pir::PatternBenefit benefit = 1,
-                         const std::vector<std::string>& generated_names = {})
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names),
-        scope_(scope) {}
+  ConstantFoldingPattern(
+      pir::IrContext* context,
+      size_t* suffix,
+      const phi::Place& place,
+      paddle::framework::Scope* scope,
+      paddle::framework::interpreter::ExecutionConfig* exe_config,
+      std::vector<std::string>* deleted_vars)
+      : RewritePattern(MatchAnyOpTypeTag(),
+                       1 /*benefit*/,
+                       context,
+                       {} /*generated_names*/),
+        counter_(suffix),
+        place_(place),
+        scope_(scope),
+        exe_config_(exe_config),
+        deleted_vars_(deleted_vars) {
+    exe_config_->create_local_scope = false;
+  }
 
   bool Match(pir::Operation* op) const override {
-    // TODO(liuyuanle): Use trait to improve robustness.
     if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() ||
-        op->isa<paddle::dialect::FetchOp>() ||
-        op->isa<paddle::dialect::ShadowOutputOp>() ||
-        op->isa<pir::ShadowOutputOp>())
+        op->isa<pir::ShadowOutputOp>() || op->isa<paddle::dialect::FetchOp>() ||
+        op->isa<paddle::dialect::FeedOp>())
       return false;
 
-    // Inputs must come from get parameter op.
+    // inputs must come from get parameter op
     for (uint32_t i = 0; i < op->num_operands(); ++i)
       if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>())
         return false;
@@ -70,73 +81,36 @@ class ConstantFoldingPattern : public pir::RewritePattern {
 
   void Rewrite(pir::Operation* op,
                pir::PatternRewriter& rewriter) const override {  // NOLINT
-    pir::Program* program = op->GetParentProgram();
-    auto temp_program = BuildProgramFromOperation(op);
-
-    std::vector<std::string> fetch_var_names;
-    auto block = temp_program->block();
-    for (auto it = block->begin(); it != block->end(); ++it) {
-      if ((*it)->isa<paddle::dialect::FetchOp>()) {
-        size_t index = (*it)
-                           ->attributes()
-                           .at("col")
-                           .dyn_cast<pir::Int32Attribute>()
-                           .data();
-
-        if (fetch_var_names.size() < index + 1) {
-          fetch_var_names.resize(index + 1);
-        }
-
-        fetch_var_names[index] = (*it)
-                                     ->attributes()
-                                     .at("name")
-                                     .dyn_cast<pir::StrAttribute>()
-                                     .AsString() +
-                                 "@fetch";
-      }
-    }
+    VLOG(4) << "constant_folding_pass applies on [" << op->name() << "] op";
+    pir::Program new_program(ir_context());
+    auto output_var_name = BuildProgramFromOperation(op, &new_program);
 
-    // Execute program
-    exe_config_.create_local_scope = false;
+    // execute program
+    exe_config_->skip_gc_vars.insert(output_var_name);
     auto kernel_program =
-        paddle::dialect::PdOpLowerToKernelPass(temp_program.get());
-    paddle::framework::InterpreterCore core(phi::CPUPlace{},
-                                            fetch_var_names,
-                                            kernel_program->block(),
-                                            scope_,
-                                            exe_config_);
-
-    paddle::framework::FetchList fetch_list = core.Run({});
-
-    // TODO(liuyuanle): Support multiple output.
-    auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
-    std::unique_ptr<pir::Parameter> parameter =
-        std::make_unique<pir::Parameter>(
-            reinterpret_cast<void*>(out_tensor.data()),
-            out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
-            op->result(0).type());
-
-    std::string param_name =
-        "@constant_folding_pass@_" + std::to_string(suffix_++);
-    exe_config_.skip_gc_vars.insert(param_name);
-
-    auto* param_var = scope_->Var(param_name);
-    auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
-    *param_tensor = out_tensor;
-    program->SetParameter(param_name, std::move(parameter));
-    // rewriter.SetInsertionPoint(op);
-    auto get_parameter_op =
-        rewriter.Build<pir::GetParameterOp>(param_name, op->result(0).type());
+        paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
+    paddle::framework::InterpreterCore core(
+        place_, {}, kernel_program->block(), scope_, *exe_config_);
+
+    core.Run({});
 
+    // TODO(liuyuanle): support multiple output
+    auto get_parameter_op = rewriter.Build<pir::GetParameterOp>(
+        output_var_name, op->result(0).type());
+    get_parameter_op->set_attribute(
+        kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));
+
+    VLOG(4) << "constant_folding_pass applied on [" << op->name() << "] op";
     rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
     rewriter.EraseOp(op);
   }
 
  private:
-  std::unique_ptr<pir::Program> BuildProgramFromOperation(
-      pir::Operation* op) const {
-    auto program = std::make_unique<pir::Program>(ir_context());
-    pir::Builder builder = pir::Builder(ir_context(), program->block());
+  std::string BuildProgramFromOperation(pir::Operation* op,
+                                        pir::Program* new_program) const {
+    pir::Builder builder = pir::Builder(ir_context(), new_program->block());
+    std::string output_var_name =
+        "constant_folding@" + std::to_string((*counter_)++);
 
     // prepare op inputs
     std::vector<pir::Value> op_inputs;
@@ -147,15 +121,14 @@ class ConstantFoldingPattern : public pir::RewritePattern {
           phi::errors::InvalidArgument(
               "Op's input must be a dense tensor type."));
 
-      auto [param_name, param] =
-          pir::GetParameterFromValue(op->operand_source(i));
-      program->SetParameter(param_name,
-                            std::make_unique<pir::Parameter>(*param));
-
+      const auto& param_name =
+          pir::GetParameterNameFromValue(op->operand_source(i));
       auto* param_var = scope_->FindVar(param_name);
       PADDLE_ENFORCE_NOT_NULL(
           param_var,
           phi::errors::InvalidArgument("Parameter var not in scope."));
+      output_var_name = output_var_name + "_" + param_name;
+      deleted_vars_->push_back(param_name);
 
       auto get_parameter_op = builder.Build<pir::GetParameterOp>(
           param_name, op->operand_source(i).type());
@@ -171,7 +144,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
     auto* temp_op =
         builder.Build(op_inputs, op->attributes(), output_types, op->info());
 
-    // TODO(liuyuanle): Support multiple output.
+    // TODO(liuyuanle): support multiple output
     // for (uint32_t i = 0; i < op->num_results(); i++) {
     PADDLE_ENFORCE_EQ(
         temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
@@ -179,52 +152,72 @@ class ConstantFoldingPattern : public pir::RewritePattern {
         phi::errors::InvalidArgument(
             "Op's output must be a dense tensor type."));
 
-    builder.Build<paddle::dialect::FetchOp>(
-        temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
+    builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name);
     // }
 
-    return program;
+    return output_var_name;
   }
 
  private:
+  size_t* counter_{nullptr};
+  phi::Place place_;
   paddle::framework::Scope* scope_{nullptr};
-  inline static size_t suffix_{0};
-  inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
+  paddle::framework::interpreter::ExecutionConfig* exe_config_{nullptr};
+  std::vector<std::string>* deleted_vars_{nullptr};
 };
 
 class ConstantFoldingPass : public pir::Pass {
  public:
-  ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}
+  ConstantFoldingPass(const phi::Place& place, paddle::framework::Scope* scope)
+      : pir::Pass("constant_folding_pass", 1), place_(place), scope_(scope) {
+    PADDLE_ENFORCE_NOT_NULL(
+        scope_, phi::errors::InvalidArgument("scope can not be nullptr"));
+  }
 
+ private:
   bool Initialize(pir::IrContext* context) override {
     pir::RewritePatternSet ps(context);
-    ps.Add<ConstantFoldingPattern>(context, &scope_);
+    ps.Add<ConstantFoldingPattern>(
+        context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
     patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
     return true;
   }
 
   void Run(pir::Operation* op) override {
+    size_t op_nums = op->GetParentProgram()->block()->size();
     pir::GreedyRewriteConfig cfg;
     cfg.use_top_down_traversal = true;
     cfg.max_iterations = 10;
     pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
+
+    // delete old parameter var
+    scope_->EraseVars(deleted_vars_);
+    LOG(INFO) << " ------ constant_folding_pass done: [" << counter_ << "/"
+              << op_nums << "]";
   }
 
   bool CanApplyOn(pir::Operation* op) const override {
+    // TODO(liuyuanle): remove op->isa<::pir::ModuleOp>()
     return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
   }
 
  private:
+  size_t counter_{0};
+  phi::Place place_;
+  paddle::framework::Scope* scope_{nullptr};
+  paddle::framework::interpreter::ExecutionConfig exe_config_{};
+  std::vector<std::string> deleted_vars_;
+
   pir::FrozenRewritePatternSet patterns_;
-  paddle::framework::Scope scope_;
 };
 
 }  // namespace
 
 namespace pir {
 
-std::unique_ptr<Pass> CreateConstantFoldingPass() {
-  return std::make_unique<ConstantFoldingPass>();
+std::unique_ptr<Pass> CreateConstantFoldingPass(
+    const phi::Place& place, paddle::framework::Scope* scope) {
+  return std::make_unique<ConstantFoldingPass>(place, scope);
 }
 
 }  // namespace pir
diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.h b/paddle/fluid/pir/transforms/constant_folding_pass.h
index b49c9d90493b1c..0939ee589d448d 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.h
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.h
@@ -15,12 +15,20 @@
 #pragma once
 
 #include <memory>
+#include "paddle/phi/common/place.h"
 #include "paddle/pir/core/dll_decl.h"
 
+namespace paddle {
+namespace framework {
+class Scope;
+}
+}  // namespace paddle
+
 namespace pir {
 
 class Pass;
 
-IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();
+IR_API std::unique_ptr<Pass> CreateConstantFoldingPass(
+    const phi::Place& place, paddle::framework::Scope* scope);
 
 }  // namespace pir
diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
index 0760a9de420b82..bfde883cac67a9 100644
--- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
+++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
@@ -35,8 +35,7 @@ class DeadCodeEliminationPattern : public pir::RewritePattern {
   }
 
   bool Match(pir::Operation* op) const override {
-    if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>() ||
-        op->isa<paddle::dialect::ShadowOutputOp>()) {
+    if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>()) {
       return false;
     }
     return op->use_empty();
diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc
index 8bd5028688b133..1d7c226197668f 100644
--- a/paddle/fluid/pir/transforms/transform_general_functions.cc
+++ b/paddle/fluid/pir/transforms/transform_general_functions.cc
@@ -22,8 +22,7 @@
 
 namespace pir {
 
-std::pair<std::string, pir::Parameter*> GetParameterFromValue(
-    pir::Value value) {
+std::string GetParameterNameFromValue(pir::Value value) {
   pir::GetParameterOp op =
       value.dyn_cast<OpResult>().owner()->dyn_cast<pir::GetParameterOp>();
   PADDLE_ENFORCE_NOT_NULL(
@@ -37,10 +36,7 @@ std::pair<std::string, pir::Parameter*> GetParameterFromValue(
                          .at(op.attributes_name[0])
                          .dyn_cast<pir::StrAttribute>()
                          .AsString();
-  pir::Parameter* param = program->GetParameter(name);
-  PADDLE_ENFORCE_NOT_NULL(
-      param, phi::errors::InvalidArgument("Parameter should not be null."));
-  return {name, param};
+  return name;
 }
 
 const phi::DDim& GetShapeFromValue(pir::Value value) {
diff --git a/paddle/fluid/pir/transforms/transform_general_functions.h b/paddle/fluid/pir/transforms/transform_general_functions.h
index 77c790235b8329..0d35ff776ce8c1 100644
--- a/paddle/fluid/pir/transforms/transform_general_functions.h
+++ b/paddle/fluid/pir/transforms/transform_general_functions.h
@@ -25,16 +25,16 @@
 namespace pir {
 
 /**
- * @brief Get the [name, parameter] pair of pararmeter from a value.
+ * @brief Get the name of parameter from a value.
  *
  * @note The value must be a output of a GetParameterOp.
  *
  * @param pir::Value
  *
- * @return std::pair<std::string, pir::Parameter*>
+ * @return std::string
  */
 
-std::pair<std::string, pir::Parameter*> GetParameterFromValue(pir::Value value);
+std::string GetParameterNameFromValue(pir::Value value);
 
 /**
  * @brief Get tensor's shape from a value.
diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
index 28661875b3bafc..df112a6a3f44c3 100644
--- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
@@ -20,11 +20,15 @@
 #include <sstream>
 #include <vector>
 
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/pir/transforms/constant_folding_pass.h"
 #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/transform_general_functions.h"
-#include "paddle/phi/core/kernel_registry.h"
+
 #include "paddle/pir/core/builder.h"
 #include "paddle/pir/core/builtin_attribute.h"
 #include "paddle/pir/core/builtin_dialect.h"
@@ -44,13 +48,9 @@
 #include "paddle/pir/pattern_rewrite/pattern_match.h"
 #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
 
-// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
-// paddle/fluid/pir/dialect/CMakeLists.txt.
-#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
-
-#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
-#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/phi/common/place.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 // build Conv2dFusionOp
 #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
@@ -1120,10 +1120,10 @@ TEST(pattern_rewrite, Patterns) {
   BuildProgram(builder);
 
   EXPECT_EQ(program.block()->size(), 11u);
-
+  paddle::framework::Scope scope;
   pir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
-  pm.AddPass(pir::CreateConstantFoldingPass());
+  pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope));
   pm.AddPass(pir::CreateDeadCodeEliminationPass());
   pm.EnablePassTiming();
   pm.EnableIRPrinting();

From 1185d56e67de7c16246737afb3b532a62b334f2f Mon Sep 17 00:00:00 2001
From: yuanlehome <yuanlehome@163.com>
Date: Wed, 8 Nov 2023 08:00:34 +0000
Subject: [PATCH 4/8] update

---
 paddle/fluid/pir/transforms/constant_folding_pass.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 66019b3e2c288c..5934294769752e 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -109,8 +109,6 @@ class ConstantFoldingPattern : public pir::RewritePattern {
   std::string BuildProgramFromOperation(pir::Operation* op,
                                         pir::Program* new_program) const {
     pir::Builder builder = pir::Builder(ir_context(), new_program->block());
-    std::string output_var_name =
-        "constant_folding@" + std::to_string((*counter_)++);
 
     // prepare op inputs
     std::vector<pir::Value> op_inputs;
@@ -127,7 +125,6 @@ class ConstantFoldingPattern : public pir::RewritePattern {
       PADDLE_ENFORCE_NOT_NULL(
           param_var,
           phi::errors::InvalidArgument("Parameter var not in scope."));
-      output_var_name = output_var_name + "_" + param_name;
       deleted_vars_->push_back(param_name);
 
       auto get_parameter_op = builder.Build<pir::GetParameterOp>(
@@ -152,6 +149,10 @@ class ConstantFoldingPattern : public pir::RewritePattern {
         phi::errors::InvalidArgument(
             "Op's output must be a dense tensor type."));
 
+    std::stringstream ss;
+    ss << std::chrono::high_resolution_clock::now().time_since_epoch().count();
+    std::string output_var_name = ss.str() + std::to_string((*counter_)++);
+
     builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name);
     // }
 

From 700ebc71e1e7b1641f5038c52a763a4e20cf11a6 Mon Sep 17 00:00:00 2001
From: yuanlehome <yuanlehome@163.com>
Date: Wed, 8 Nov 2023 08:07:05 +0000
Subject: [PATCH 5/8] update

---
 paddle/fluid/pir/transforms/constant_folding_pass.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 5934294769752e..fb3d7de7331f38 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -151,7 +151,8 @@ class ConstantFoldingPattern : public pir::RewritePattern {
 
     std::stringstream ss;
     ss << std::chrono::high_resolution_clock::now().time_since_epoch().count();
-    std::string output_var_name = ss.str() + std::to_string((*counter_)++);
+    std::string output_var_name =
+        "constant_folding@_" + ss.str() + std::to_string((*counter_)++);
 
     builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name);
     // }

From a6900b3207edeee67b2c9f8a806426b1af13e812 Mon Sep 17 00:00:00 2001
From: yuanlehome <yuanlehome@163.com>
Date: Wed, 8 Nov 2023 08:53:56 +0000
Subject: [PATCH 6/8] fix compile

---
 paddle/fluid/inference/api/CMakeLists.txt     |  6 +----
 .../pir/transforms/constant_folding_pass.cc   | 24 ++++++++++---------
 2 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt
index 4aadb77aedc0e3..f15bd26c4476a6 100755
--- a/paddle/fluid/inference/api/CMakeLists.txt
+++ b/paddle/fluid/inference/api/CMakeLists.txt
@@ -64,10 +64,6 @@ if(WIN32)
   target_link_libraries(paddle_inference_api phi)
 endif()
 
-set(PIR_PASS_DEPS
-    pd_constant_folding_pass dead_code_elimination_pass pd_op_to_kernel_pass
-    pd_inplace_pass replace_fetch_with_shadow_output_pass)
-
 set(inference_deps
     ${analysis_deps}
     paddle_inference_api
@@ -75,7 +71,7 @@ set(inference_deps
     analysis_config
     naive_executor
     ${GLOB_PASS_LIB}
-    ${PIR_PASS_DEPS})
+    transform)
 
 if(WITH_GPU AND TENSORRT_FOUND)
   set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter)
diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index fb3d7de7331f38..a89caad7d65ff1 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -113,11 +113,11 @@ class ConstantFoldingPattern : public pir::RewritePattern {
     // prepare op inputs
     std::vector<pir::Value> op_inputs;
     for (uint32_t i = 0; i < op->num_operands(); i++) {
-      PADDLE_ENFORCE_EQ(
-          op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
-          true,
-          phi::errors::InvalidArgument(
-              "Op's input must be a dense tensor type."));
+      // PADDLE_ENFORCE_EQ(
+      //     op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
+      //     true,
+      //     phi::errors::InvalidArgument(
+      //         "Op [%s] input must be a dense tensor type.", op->name()));
 
       const auto& param_name =
           pir::GetParameterNameFromValue(op->operand_source(i));
@@ -125,7 +125,9 @@ class ConstantFoldingPattern : public pir::RewritePattern {
       PADDLE_ENFORCE_NOT_NULL(
           param_var,
           phi::errors::InvalidArgument("Parameter var not in scope."));
-      deleted_vars_->push_back(param_name);
+      if (op->operand_source(i).use_count() == 1) {
+        deleted_vars_->push_back(param_name);
+      }
 
       auto get_parameter_op = builder.Build<pir::GetParameterOp>(
           param_name, op->operand_source(i).type());
@@ -143,11 +145,11 @@ class ConstantFoldingPattern : public pir::RewritePattern {
 
     // TODO(liuyuanle): support multiple output
     // for (uint32_t i = 0; i < op->num_results(); i++) {
-    PADDLE_ENFORCE_EQ(
-        temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
-        true,
-        phi::errors::InvalidArgument(
-            "Op's output must be a dense tensor type."));
+    // PADDLE_ENFORCE_EQ(
+    //     temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
+    //     true,
+    //     phi::errors::InvalidArgument(
+    //         "Op [%s] output must be a dense tensor type.", temp_op->name()));
 
     std::stringstream ss;
     ss << std::chrono::high_resolution_clock::now().time_since_epoch().count();

From 400762de23d89c20380799590c9d05df21a5ffaa Mon Sep 17 00:00:00 2001
From: yuanlehome <yuanlehome@163.com>
Date: Wed, 8 Nov 2023 11:14:27 +0000
Subject: [PATCH 7/8] fix dce

---
 paddle/fluid/pir/transforms/dead_code_elimination_pass.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
index bfde883cac67a9..9c6fcd9b3d9ca4 100644
--- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
+++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
@@ -17,6 +17,7 @@
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/pir/core/builtin_op.h"
 #include "paddle/pir/core/program.h"
+#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
 #include "paddle/pir/pass/pass.h"
 #include "paddle/pir/pass/pass_registry.h"
 #include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
@@ -35,7 +36,8 @@ class DeadCodeEliminationPattern : public pir::RewritePattern {
   }
 
   bool Match(pir::Operation* op) const override {
-    if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>()) {
+    if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>() ||
+        op->isa<pir::YieldOp>()) {
       return false;
     }
     return op->use_empty();

From b4e4a1c5a7b8d8425e36adfbfc644eeb7927a099 Mon Sep 17 00:00:00 2001
From: yuanlehome <yuanlehome@163.com>
Date: Wed, 8 Nov 2023 12:04:42 +0000
Subject: [PATCH 8/8] enhance judgement

---
 .../pir/transforms/constant_folding_pass.cc   | 39 ++++++++++++-------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index a89caad7d65ff1..39daebc1a3b8f9 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -72,10 +72,10 @@ class ConstantFoldingPattern : public pir::RewritePattern {
         op->isa<paddle::dialect::FeedOp>())
       return false;
 
-    // inputs must come from get parameter op
-    for (uint32_t i = 0; i < op->num_operands(); ++i)
-      if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>())
-        return false;
+    if (!ValidOp(op)) {
+      return false;
+    }
+
     return true;
   }
 
@@ -106,6 +106,26 @@ class ConstantFoldingPattern : public pir::RewritePattern {
   }
 
  private:
+  // Returns true iff `op` is eligible for constant folding:
+  //   1. every input is produced by a get_parameter op,
+  //   2. every input is a dense tensor,
+  //   3. every output is a dense tensor.
+  bool ValidOp(pir::Operation* op) const {
+    for (uint32_t i = 0; i < op->num_operands(); i++) {
+      // 1. inputs must come from get_parameter op
+      // 2. inputs must be a dense tensor type
+      if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>() ||
+          !op->operand_source(i)
+               .type()
+               .isa<paddle::dialect::DenseTensorType>()) {
+        return false;
+      }
+    }
+    // 3. outputs must be a dense tensor type (checked once, outside the
+    // operand loop, so it also runs for ops with zero operands)
+    for (uint32_t i = 0; i < op->num_results(); i++) {
+      if (!op->result(i).type().isa<paddle::dialect::DenseTensorType>()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   std::string BuildProgramFromOperation(pir::Operation* op,
                                         pir::Program* new_program) const {
     pir::Builder builder = pir::Builder(ir_context(), new_program->block());
@@ -113,12 +133,6 @@ class ConstantFoldingPattern : public pir::RewritePattern {
     // prepare op inputs
     std::vector<pir::Value> op_inputs;
     for (uint32_t i = 0; i < op->num_operands(); i++) {
-      // PADDLE_ENFORCE_EQ(
-      //     op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
-      //     true,
-      //     phi::errors::InvalidArgument(
-      //         "Op [%s] input must be a dense tensor type.", op->name()));
-
       const auto& param_name =
           pir::GetParameterNameFromValue(op->operand_source(i));
       auto* param_var = scope_->FindVar(param_name);
@@ -145,11 +159,6 @@ class ConstantFoldingPattern : public pir::RewritePattern {
 
     // TODO(liuyuanle): support multiple output
     // for (uint32_t i = 0; i < op->num_results(); i++) {
-    // PADDLE_ENFORCE_EQ(
-    //     temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
-    //     true,
-    //     phi::errors::InvalidArgument(
-    //         "Op [%s] output must be a dense tensor type.", temp_op->name()));
 
     std::stringstream ss;
     ss << std::chrono::high_resolution_clock::now().time_since_epoch().count();