From 722da6a9ba5fd856325518fa1e4e51d5138f2741 Mon Sep 17 00:00:00 2001
From: phlrain <phliuhongyu@126.com>
Date: Wed, 6 Dec 2023 09:27:21 +0000
Subject: [PATCH] add build cinn pass constrain

---
 .../fluid/pir/transforms/build_cinn_pass.cc   | 51 +++++++++++++++++--
 1 file changed, 48 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/pir/transforms/build_cinn_pass.cc b/paddle/fluid/pir/transforms/build_cinn_pass.cc
index 281f222501cb61..0716d979cd8f2c 100644
--- a/paddle/fluid/pir/transforms/build_cinn_pass.cc
+++ b/paddle/fluid/pir/transforms/build_cinn_pass.cc
@@ -131,6 +131,8 @@ std::string GetDebugInfo(const std::unordered_set<std::string>& names) {
   return debug_info;
 }
 
+bool IsSupportCinn(pir::Operation* op);
+
 // In case of op has some attributes generated by FullOp, it need
 // implement OpPattern in pd_to_cinn_pass. Otherwise, we mark them
 // as unimplement ops.
@@ -139,17 +141,57 @@ bool UnimplementOps(pir::Operation* op) {
   // CINN
   if (op->isa<paddle::dialect::FullOp>()) {
     auto out = op->result(0);
-    if (out.use_count() > 0 &&
-        out.first_use().owner()->isa<paddle::dialect::UniformOp>()) {
-      return true;
+    if (out.use_count() > 0) {
+      return !IsSupportCinn(out.first_use().owner());
     }
+
+    return false;
   } else if (op->isa<paddle::dialect::DropoutOp>()) {
     return true;
   }
   return false;
 }
 
+bool HaveZeroDimInput(pir::Operation* op) {
+  bool have_zero_dim = false;
+  for (size_t i = 0; i < op->num_operands(); ++i) {
+    auto in = op->operand_source(i);
+    if (in) {
+      if (auto tensor_type =
+              in.type().dyn_cast<paddle::dialect::DenseTensorType>()) {
+        if (tensor_type.dims().size() == 0) {
+          have_zero_dim = true;
+        }
+      }
+    }
+  }
+
+  return have_zero_dim;
+}
+
+bool AllInputDenseTensor(pir::Operation* op) {
+  bool all_denese_tensor = true;
+  for (size_t i = 0; i < op->num_operands(); ++i) {
+    auto in = op->operand_source(i);
+    if (in) {
+      if (!(in.type().isa<paddle::dialect::DenseTensorType>())) {
+        all_denese_tensor = false;
+      }
+    }
+  }
+
+  return all_denese_tensor;
+}
+
 bool IsSupportCinn(pir::Operation* op) {
+  if (!AllInputDenseTensor(op)) {
+    return false;
+  }
+
+  if (HaveZeroDimInput(op)) {
+    return false;
+  }
+
   auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
   auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
   VLOG(4) << "The allowed Cinn Ops: " << GetDebugInfo(allow_ops);
@@ -162,6 +204,9 @@ bool IsSupportCinn(pir::Operation* op) {
 
   // Strip the dialect, like pd_op.abs -> abs
   const auto op_name = CompatibleInfo::OpName(*op);
+  if (op_name == "matmul") {
+    return false;
+  }
   OpTransInfo trans_info;
   bool is_support = CompatibleInfo::IsSupportCinn(*op) &&
                     !trans_info.default_deny_ops().count(op_name);