diff --git a/paddle/cinn/ir/ir_base.h b/paddle/cinn/ir/ir_base.h index c333448d029ae0..0047100ebcfdfc 100644 --- a/paddle/cinn/ir/ir_base.h +++ b/paddle/cinn/ir/ir_base.h @@ -110,16 +110,23 @@ class Dim; macro__(Product) \ macro__(Sum) \ macro__(PrimitiveNode) \ - macro__(IntrinsicOp) \ macro__(_BufferRange_) \ macro__(ScheduleBlock) \ macro__(ScheduleBlockRealize) \ macro__(_Dim_) \ +#define NODETY_CONTROL_OP_FOR_INTRINSIC(macro__) \ + macro__(IntrinsicOp) \ #define NODETY_FORALL(__m) \ NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ NODETY_OP_FOR_EACH(__m) \ + NODETY_CONTROL_OP_FOR_INTRINSIC(__m) \ + NODETY_CONTROL_OP_FOR_EACH(__m) + +#define NODETY_FORALL_EXCEPT_INTRINSIC(__m) \ + NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ + NODETY_OP_FOR_EACH(__m) \ NODETY_CONTROL_OP_FOR_EACH(__m) // clang-format on diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.cc b/paddle/cinn/ir/utils/ir_nodes_collector.cc index ac2f0317e9213f..e4ebaca653bae9 100644 --- a/paddle/cinn/ir/utils/ir_nodes_collector.cc +++ b/paddle/cinn/ir/utils/ir_nodes_collector.cc @@ -15,6 +15,8 @@ #include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include +#include "paddle/cinn/ir/intrinsic_ops.h" +#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_printer.h" @@ -71,8 +73,71 @@ struct IrNodesCollector : public IRVisitorRequireReImpl { } \ } - NODETY_FORALL(__m) + NODETY_FORALL_EXCEPT_INTRINSIC(__m) #undef __m + + void Visit(const ir::IntrinsicOp* op) { + switch (op->getKind()) { +#define __(x) \ + case ir::IntrinsicKind::k##x: \ + Visit(llvm::dyn_cast(op)); \ + break; + + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + } + } + + void Visit(const ir::intrinsics::GetAddr* x) { + if (x->data.defined()) { + Visit(&(x->data)); + } + } + + void Visit(const ir::intrinsics::BufferGetDataHandle* x) { + if (x->buffer.defined()) { + Visit(&(x->buffer)); + } + } + + void Visit(const ir::intrinsics::BufferGetDataConstHandle* x) { + if (x->buffer.defined()) { + Visit(&(x->buffer)); + } + } + + void Visit(const ir::intrinsics::PodValueToX* x) { + if (x->pod_value_ptr.defined()) { + Visit(&(x->pod_value_ptr)); + } + } + + void Visit(const ir::intrinsics::BufferCreate* x) { + if (x->buffer.defined()) { + Visit(&(x->buffer)); + } + } + + void Visit(const ir::intrinsics::ArgsConstruct* x) { + if (x->var.defined()) { + Expr convert = Expr(x->var); + Visit(&convert); + } + for (int i = 0; i < x->args.size(); ++i) { + if (x->args[i].defined()) { + Visit(&(x->args[i])); + } + } + } + + void Visit(const ir::intrinsics::BuiltinIntrin* x) { + for (int i = 0; i < x->args.size(); ++i) { + if (x->args[i].defined()) { + Visit(&(x->args[i])); + } + } + } + std::set visited_; };