diff --git a/paddle/fluid/pir/transforms/build_cinn_pass.cc b/paddle/fluid/pir/transforms/build_cinn_pass.cc index 281f222501cb61..0716d979cd8f2c 100644 --- a/paddle/fluid/pir/transforms/build_cinn_pass.cc +++ b/paddle/fluid/pir/transforms/build_cinn_pass.cc @@ -131,6 +131,8 @@ std::string GetDebugInfo(const std::unordered_set& names) { return debug_info; } +bool IsSupportCinn(pir::Operation* op); + // In case of op has some attributes generated by FullOp, it need // implement OpPattern in pd_to_cinn_pass. Otherwise, we mark them // as unimplement ops. @@ -139,17 +141,57 @@ bool UnimplementOps(pir::Operation* op) { // CINN if (op->isa()) { auto out = op->result(0); - if (out.use_count() > 0 && - out.first_use().owner()->isa()) { - return true; + if (out.use_count() > 0) { + return !IsSupportCinn(out.first_use().owner()); } + + return false; } else if (op->isa()) { return true; } return false; } +bool HaveZeroDimInput(pir::Operation* op) { + bool have_zero_dim = false; + for (size_t i = 0; i < op->num_operands(); ++i) { + auto in = op->operand_source(i); + if (in) { + if (auto tensor_type = + in.type().dyn_cast()) { + if (tensor_type.dims().size() == 0) { + have_zero_dim = true; + } + } + } + } + + return have_zero_dim; +} + +bool AllInputDenseTensor(pir::Operation* op) { + bool all_denese_tensor = true; + for (size_t i = 0; i < op->num_operands(); ++i) { + auto in = op->operand_source(i); + if (in) { + if (!(in.type().isa())) { + all_denese_tensor = false; + } + } + } + + return all_denese_tensor; +} + bool IsSupportCinn(pir::Operation* op) { + if (!AllInputDenseTensor(op)) { + return false; + } + + if (HaveZeroDimInput(op)) { + return false; + } + auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); VLOG(4) << "The allowed Cinn Ops: " << GetDebugInfo(allow_ops); @@ -162,6 +204,9 @@ bool IsSupportCinn(pir::Operation* op) { // Strip the dialect, like pd_op.abs -> abs const auto op_name = CompatibleInfo::OpName(*op); + if (op_name == "matmul") { + return false; + } OpTransInfo trans_info; bool is_support = CompatibleInfo::IsSupportCinn(*op) && !trans_info.default_deny_ops().count(op_name);