diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index f1642827cd48..004429148bb4 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -206,69 +206,51 @@ adding the code back to the central repo. To ease the speed of dispatching, we a Since usually one ``Object`` could be referenced in multiple places in the language, we use a shared_ptr to keep track of reference. We use ``ObjectRef`` class to represent a reference to the ``Object``. We can roughly view ``ObjectRef`` class as shared_ptr to the ``Object`` container. -We can also define subclass ``ObjectRef`` to hold each subtypes of ``Object``. Each subclass of ``Object`` needs to define the VisitAttr function. +We can also define subclass ``ObjectRef`` to hold each subtypes of ``Object``. Each subclass of ``Object`` needs to define the +RegisterReflection function. -.. code:: c - class AttrVisitor { - public: - virtual void Visit(const char* key, double* value) = 0; - virtual void Visit(const char* key, int64_t* value) = 0; - virtual void Visit(const char* key, uint64_t* value) = 0; - virtual void Visit(const char* key, int* value) = 0; - virtual void Visit(const char* key, bool* value) = 0; - virtual void Visit(const char* key, std::string* value) = 0; - virtual void Visit(const char* key, void** value) = 0; - virtual void Visit(const char* key, Type* value) = 0; - virtual void Visit(const char* key, ObjectRef* value) = 0; - // ... - }; - - class BaseAttrsNode : public Object { - public: - virtual void VisitAttrs(AttrVisitor* v) {} - // ... - }; - -Each ``Object`` subclass will override this to visit its members. Here is an example implementation of TensorNode. +Each ``Object`` subclass will override this to register its members. Here is an example implementation of IntImmNode. .. code:: c - class TensorNode : public Object { - public: - /*! \brief The shape of the tensor */ - Array shape; - /*! \brief data type in the content of the tensor */ - Type dtype; - /*! \brief the source operation, can be None */ - Operation op; - /*! \brief the output index from source operation */ - int value_index{0}; - /*! \brief constructor */ - TensorNode() {} - - void VisitAttrs(AttrVisitor* v) final { - v->Visit("shape", &shape); - v->Visit("dtype", &dtype); - v->Visit("op", &op); - v->Visit("value_index", &value_index); - } - }; - -In the above examples, both ``Operation`` and ``Array`` are ObjectRef. -The VisitAttrs gives us a reflection API to visit each member of the object. + class IntImmNode : public PrimExprNode { + public: + /*! \brief the Internal value. */ + int64_t value; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &IntImmNode::value); + } + + bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dtype); + hash_reduce(value); + } + + static constexpr const char* _type_key = "IntImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); + }; + // in cc file + TVM_FFI_STATIC_INIT_BLOCK({ IntImmNode::RegisterReflection(); }); + +The RegisterReflection gives us a reflection API to register each member of the object. We can use this function to visit the node and serialize any language object recursively. It also allows us to get members of an object easily in front-end language. -For example, in the following code, we accessed the op field of the TensorNode. +For example, we can access the value field of the IntImmNode. .. code:: python import tvm - from tvm import te - x = te.placeholder((3,4), name="x") - # access the op field of TensorNode - print(x.op.name) + x = tvm.tir.IntImm("int32", 1) + # access the value field of IntImmNode + print(x.value) New ``Object`` can be added to C++ without changing the front-end runtime, making it easy to make extensions to the compiler stack. Note that this is not the fastest way to expose members to front-end language, but might be one of the simplest diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 6527643fbf61..29fd49384cf6 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -26,6 +26,7 @@ #ifndef TVM_IR_DIAGNOSTIC_H_ #define TVM_IR_DIAGNOSTIC_H_ +#include #include #include @@ -65,13 +66,16 @@ class DiagnosticNode : public Object { /*! \brief The diagnostic message. */ String message; - // override attr visitor - void VisitAttrs(AttrVisitor* v) { - v->Visit("level", &level); - v->Visit("span", &span); - v->Visit("message", &message); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("level", &DiagnosticNode::level) + .def_ro("span", &DiagnosticNode::span) + .def_ro("message", &DiagnosticNode::message); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const DiagnosticNode* other, SEqualReducer equal) const { return equal(this->level, other->level) && equal(this->span, other->span) && equal(this->message, other->message); @@ -165,8 +169,12 @@ class DiagnosticRendererNode : public Object { public: ffi::TypedFunction renderer; - // override attr visitor - void VisitAttrs(AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("renderer", &DiagnosticRendererNode::renderer); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "DiagnosticRenderer"; TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object); @@ -199,11 +207,15 @@ class DiagnosticContextNode : public Object { /*! \brief The renderer set for the context. */ DiagnosticRenderer renderer; - void VisitAttrs(AttrVisitor* v) { - v->Visit("module", &module); - v->Visit("diagnostics", &diagnostics); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("module", &DiagnosticContextNode::module) + .def_ro("diagnostics", &DiagnosticContextNode::diagnostics); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const DiagnosticContextNode* other, SEqualReducer equal) const { return equal(module, other->module) && equal(diagnostics, other->diagnostics); } diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index ab5cf31c6c86..062a5212de2c 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -25,6 +25,7 @@ #define TVM_IR_ENV_FUNC_H_ #include +#include #include #include @@ -48,7 +49,12 @@ class EnvFuncNode : public Object { /*! \brief constructor */ EnvFuncNode() {} - void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name", &EnvFuncNode::name); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { // name uniquely identifies the env function. diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 61c170b36639..3bb43a1594de 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -217,6 +217,11 @@ class BaseFuncNode : public RelaxExprNode { return LinkageType::kInternal; } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("attrs", &BaseFuncNode::attrs); + } + static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelaxExprNode); diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 57bc57fd099a..4fbfeefa2399 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -25,6 +25,7 @@ #ifndef TVM_IR_GLOBAL_INFO_H_ #define TVM_IR_GLOBAL_INFO_H_ +#include #include #include @@ -68,12 +69,17 @@ class VDeviceNode : public GlobalInfoNode { */ int vdevice_id; MemoryScope memory_scope; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("target", &target); - v->Visit("vdevice_id", &vdevice_id); - v->Visit("memory_scope", &memory_scope); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("target", &VDeviceNode::target) + .def_ro("vdevice_id", &VDeviceNode::vdevice_id) + .def_ro("memory_scope", &VDeviceNode::memory_scope); } + static constexpr bool _type_has_method_visit_attrs = false; + TVM_DLL bool SEqualReduce(const VDeviceNode* other, SEqualReducer equal) const { return equal(target, other->target) && equal(vdevice_id, other->vdevice_id) && equal(memory_scope, other->memory_scope); @@ -103,7 +109,13 @@ class VDevice : public GlobalInfo { */ class DummyGlobalInfoNode : public GlobalInfoNode { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "DummyGlobalInfo"; TVM_DLL bool SEqualReduce(const DummyGlobalInfoNode* other, SEqualReducer equal) const { diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h index 9ce0da5e02a3..827b643a9b64 100644 --- a/include/tvm/ir/global_var_supply.h +++ b/include/tvm/ir/global_var_supply.h @@ -27,6 +27,7 @@ #include #include +#include "tvm/ffi/reflection/reflection.h" #include "tvm/ir/expr.h" #include "tvm/ir/module.h" #include "tvm/ir/name_supply.h" @@ -75,7 +76,12 @@ class GlobalVarSupplyNode : public Object { */ void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false); - void VisitAttrs(AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; /*! \brief The NameSupply used to generate unique name hints to GlobalVars. */ NameSupply name_supply_; diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index b1ef86c12c58..c571c09777fd 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -26,6 +26,7 @@ #ifndef TVM_IR_INSTRUMENT_H_ #define TVM_IR_INSTRUMENT_H_ +#include #include #include @@ -136,7 +137,12 @@ class PassInstrumentNode : public Object { */ virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0; - void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name", &PassInstrumentNode::name); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "instrument.PassInstrument"; TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 994f3a4bb86a..41c8cffbc21f 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -77,7 +78,7 @@ class IRModuleNode : public Object { * * \return The result * - * \tparam TOBjectRef the expected object type. + * \tparam TObjectRef the expected object type. * \throw Error if the key exists but the value does not match TObjectRef * * \code @@ -129,14 +130,18 @@ class IRModuleNode : public Object { IRModuleNode() : source_map() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("functions", &functions); - v->Visit("global_var_map_", &global_var_map_); - v->Visit("source_map", &source_map); - v->Visit("attrs", &attrs); - v->Visit("global_infos", &global_infos); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("functions", &IRModuleNode::functions) + .def_ro("global_var_map_", &IRModuleNode::global_var_map_) + .def_ro("source_map", &IRModuleNode::source_map) + .def_ro("attrs", &IRModuleNode::attrs) + .def_ro("global_infos", &IRModuleNode::global_infos); } + static constexpr bool _type_has_method_visit_attrs = false; + TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const; TVM_DLL void SHashReduce(SHashReducer hash_reduce) const; diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index df0fba16cb6c..2fbf42fd9c1a 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -24,14 +24,15 @@ #ifndef TVM_IR_NAME_SUPPLY_H_ #define TVM_IR_NAME_SUPPLY_H_ +#include +#include + #include #include #include #include #include -#include "tvm/ir/expr.h" - namespace tvm { /*! @@ -80,7 +81,7 @@ class NameSupplyNode : public Object { */ bool ContainsName(const String& name, bool add_prefix = true); - void VisitAttrs(AttrVisitor* v) {} + static constexpr bool _type_has_method_visit_attrs = false; // Prefix for all GlobalVar names. It can be empty. std::string prefix_; diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 9c758a52b384..3e864d2d4bc2 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -26,6 +26,7 @@ #define TVM_IR_OP_H_ #include +#include #include #include #include @@ -90,16 +91,20 @@ class OpNode : public RelaxExprNode { */ int32_t support_level = 10; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("op_type", &op_type); - v->Visit("description", &description); - v->Visit("arguments", &arguments); - v->Visit("attrs_type_key", &attrs_type_key); - v->Visit("num_inputs", &num_inputs); - v->Visit("support_level", &support_level); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &OpNode::name) + .def_ro("op_type", &OpNode::op_type) + .def_ro("description", &OpNode::description) + .def_ro("arguments", &OpNode::arguments) + .def_ro("attrs_type_key", &OpNode::attrs_type_key) + .def_ro("num_inputs", &OpNode::num_inputs) + .def_ro("support_level", &OpNode::support_level); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const OpNode* other, SEqualReducer equal) const { // pointer equality is fine as there is only one op with the same name. return this == other; diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 83e2f4f375bf..27ef33a035ae 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -24,6 +24,7 @@ #define TVM_IR_SOURCE_MAP_H_ #include +#include #include #include @@ -46,8 +47,13 @@ class SourceNameNode : public Object { public: /*! \brief The source name. */ String name; - // override attr visitor - void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name", &SourceNameNode::name); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr bool _type_has_method_sequal_reduce = true; @@ -96,14 +102,18 @@ class SpanNode : public Object { /*! \brief The end column number. */ int end_column; - // override attr visitor - void VisitAttrs(AttrVisitor* v) { - v->Visit("source_name", &source_name); - v->Visit("line", &line); - v->Visit("column", &column); - v->Visit("end_line", &end_line); - v->Visit("end_column", &end_column); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("source_name", &SpanNode::source_name) + .def_ro("line", &SpanNode::line) + .def_ro("column", &SpanNode::column) + .def_ro("end_line", &SpanNode::end_line) + .def_ro("end_column", &SpanNode::end_column); } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr bool _type_has_method_sequal_reduce = true; bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { @@ -134,12 +144,13 @@ class SequentialSpanNode : public SpanNode { /*! \brief The original source list of spans to construct a sequential span. */ Array spans; - // override attr visitor - void VisitAttrs(AttrVisitor* v) { - SpanNode::VisitAttrs(v); - v->Visit("spans", &spans); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("spans", &SequentialSpanNode::spans); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "SequentialSpan"; TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode); @@ -188,12 +199,15 @@ class SourceNode : public Object { /*! \brief A mapping of line breaks into the raw source. */ std::vector> line_map; - // override attr visitor - void VisitAttrs(AttrVisitor* v) { - v->Visit("source_name", &source_name); - v->Visit("source", &source); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("source_name", &SourceNode::source_name) + .def_ro("source", &SourceNode::source); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "Source"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceNode, Object); }; @@ -218,8 +232,12 @@ class SourceMapObj : public Object { /*! \brief The source mapping. */ Map source_map; - // override attr visitor - void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("source_map", &SourceMapObj::source_map); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const SourceMapObj* other, SEqualReducer equal) const { return equal(source_map, other->source_map); diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 7d9ff940a816..353f0a69cbf3 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -57,6 +57,7 @@ #define TVM_IR_TRANSFORM_H_ #include +#include #include #include #include @@ -122,15 +123,19 @@ class PassContextNode : public Object { return GetConfig(key, Optional(default_value)); } - void VisitAttrs(AttrVisitor* v) { - v->Visit("opt_level", &opt_level); - v->Visit("required_pass", &required_pass); - v->Visit("disabled_pass", &disabled_pass); - v->Visit("instruments", &instruments); - v->Visit("config", &config); - v->Visit("diag_ctx", &diag_ctx); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("opt_level", &PassContextNode::opt_level) + .def_ro("required_pass", &PassContextNode::required_pass) + .def_ro("disabled_pass", &PassContextNode::disabled_pass) + .def_ro("instruments", &PassContextNode::instruments) + .def_ro("config", &PassContextNode::config) + .def_ro("diag_ctx", &PassContextNode::diag_ctx); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "transform.PassContext"; static constexpr bool _type_has_method_sequal_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); @@ -311,13 +316,17 @@ class PassInfoNode : public Object { PassInfoNode() = default; - void VisitAttrs(AttrVisitor* v) { - v->Visit("opt_level", &opt_level); - v->Visit("name", &name); - v->Visit("required", &required); - v->Visit("traceable", &traceable); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("opt_level", &PassInfoNode::opt_level) + .def_ro("name", &PassInfoNode::name) + .def_ro("required", &PassInfoNode::required) + .def_ro("traceable", &PassInfoNode::traceable); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "transform.PassInfo"; static constexpr bool _type_has_method_sequal_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); @@ -374,7 +383,7 @@ class PassNode : public Object { */ virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; - void VisitAttrs(AttrVisitor* v) {} + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "transform.Pass"; TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); @@ -432,11 +441,15 @@ class SequentialNode : public PassNode { /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - v->Visit("passes", &passes); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pass_info", &SequentialNode::pass_info) + .def_ro("passes", &SequentialNode::passes); } + static constexpr bool _type_has_method_visit_attrs = false; + /*! * \brief Get the pass information/meta data. */ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 2e49a9c5185b..5ca35449fc03 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -50,6 +50,7 @@ #define TVM_IR_TYPE_H_ #include +#include #include #include #include @@ -110,7 +111,12 @@ class PrimTypeNode : public TypeNode { */ runtime::DataType dtype; - void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("dtype", &PrimTypeNode::dtype); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const { return equal(dtype, other->dtype); @@ -159,11 +165,15 @@ class PointerTypeNode : public TypeNode { */ String storage_scope; - void VisitAttrs(AttrVisitor* v) { - v->Visit("element_type", &element_type); - v->Visit("storage_scope", &storage_scope); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("element_type", &PointerTypeNode::element_type) + .def_ro("storage_scope", &PointerTypeNode::storage_scope); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const { // Make "global" equal to "" String lhs_scope = storage_scope.empty() ? "global" : storage_scope; @@ -208,11 +218,15 @@ class TupleTypeNode : public TypeNode { TupleTypeNode() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("fields", &fields); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("fields", &TupleTypeNode::fields) + .def_ro("span", &TupleTypeNode::span); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const { return equal(fields, other->fields); } @@ -274,12 +288,16 @@ class FuncTypeNode : public TypeNode { /*! \brief The type of return value. */ Type ret_type; - void VisitAttrs(AttrVisitor* v) { - v->Visit("arg_types", &arg_types); - v->Visit("ret_type", &ret_type); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("arg_types", &FuncTypeNode::arg_types) + .def_ro("ret_type", &FuncTypeNode::ret_type) + .def_ro("span", &FuncTypeNode::span); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const { // type params first as they defines type vars. return equal(arg_types, other->arg_types) && equal(ret_type, other->ret_type); diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index 721ae0932cdd..a47db66be553 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -121,29 +122,33 @@ class PrinterConfigNode : public Object { /*! \brief Object to be annotated. */ Map obj_to_annotate = Map(); - void VisitAttrs(AttrVisitor* v) { - v->Visit("binding_names", &binding_names); - v->Visit("show_meta", &show_meta); - v->Visit("ir_prefix", &ir_prefix); - v->Visit("tir_prefix", &tir_prefix); - v->Visit("relax_prefix", &relax_prefix); - v->Visit("module_alias", &module_alias); - v->Visit("buffer_dtype", &buffer_dtype); - v->Visit("int_dtype", &int_dtype); - v->Visit("float_dtype", &float_dtype); - v->Visit("verbose_expr", &verbose_expr); - v->Visit("indent_spaces", &indent_spaces); - v->Visit("print_line_numbers", &print_line_numbers); - v->Visit("num_context_lines", &num_context_lines); - v->Visit("syntax_sugar", &syntax_sugar); - v->Visit("show_object_address", &show_object_address); - v->Visit("show_all_struct_info", &show_all_struct_info); - v->Visit("path_to_underline", &path_to_underline); - v->Visit("path_to_annotate", &path_to_annotate); - v->Visit("obj_to_underline", &obj_to_underline); - v->Visit("obj_to_annotate", &obj_to_annotate); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("binding_names", &PrinterConfigNode::binding_names) + .def_ro("show_meta", &PrinterConfigNode::show_meta) + .def_ro("ir_prefix", &PrinterConfigNode::ir_prefix) + .def_ro("tir_prefix", &PrinterConfigNode::tir_prefix) + .def_ro("relax_prefix", &PrinterConfigNode::relax_prefix) + .def_ro("module_alias", &PrinterConfigNode::module_alias) + .def_ro("buffer_dtype", &PrinterConfigNode::buffer_dtype) + .def_ro("int_dtype", &PrinterConfigNode::int_dtype) + .def_ro("float_dtype", &PrinterConfigNode::float_dtype) + .def_ro("verbose_expr", &PrinterConfigNode::verbose_expr) + .def_ro("indent_spaces", &PrinterConfigNode::indent_spaces) + .def_ro("print_line_numbers", &PrinterConfigNode::print_line_numbers) + .def_ro("num_context_lines", &PrinterConfigNode::num_context_lines) + .def_ro("syntax_sugar", &PrinterConfigNode::syntax_sugar) + .def_ro("show_object_address", &PrinterConfigNode::show_object_address) + .def_ro("show_all_struct_info", &PrinterConfigNode::show_all_struct_info) + .def_ro("path_to_underline", &PrinterConfigNode::path_to_underline) + .def_ro("path_to_annotate", &PrinterConfigNode::path_to_annotate) + .def_ro("obj_to_underline", &PrinterConfigNode::obj_to_underline) + .def_ro("obj_to_annotate", &PrinterConfigNode::obj_to_annotate); } + static constexpr bool _type_has_method_visit_attrs = false; + Array GetBuiltinKeywords(); static constexpr const char* _type_key = "node.PrinterConfig"; diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h index 86bdcec30140..6c017ccc3446 100644 --- a/include/tvm/relax/binding_rewrite.h +++ b/include/tvm/relax/binding_rewrite.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAX_BINDING_REWRITE_H_ +#include #include #include #include @@ -67,11 +68,15 @@ class DataflowBlockRewriteNode : public Object { IRModule MutateIRModule(IRModule irmod); /*! \brief Visit attributes. */ - void VisitAttrs(AttrVisitor* v) { - v->Visit("dfb", &dfb_); - v->Visit("root_fn", &root_fn_); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("dfb", &DataflowBlockRewriteNode::dfb_) + .def_ro("root_fn", &DataflowBlockRewriteNode::root_fn_); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.DataflowBlockRewrite"; TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object); diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 36d07516086f..5e4a002467bf 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -215,7 +216,12 @@ class PatternSeqNode final : public Object { tvm::Array patterns; /*!< The sequence of DFPatterns */ std::vector pair_constraints; /*!< Constraints between the previous and next patterns */ - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("patterns", &PatternSeqNode::patterns); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.PatternSeq"; TVM_DECLARE_BASE_OBJECT_INFO(PatternSeqNode, Object); }; @@ -343,8 +349,12 @@ class ExprPatternNode : public DFPatternNode { public: Expr expr; /*!< The expression to match */ - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("expr", &ExprPatternNode::expr); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.ExprPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); }; @@ -368,8 +378,13 @@ class VarPatternNode : public DFPatternNode { public: String name; const String& name_hint() const { return name; } - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name", &VarPatternNode::name); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.VarPattern"; static constexpr const uint32_t _type_child_slots = 1; TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode); @@ -396,8 +411,13 @@ class VarPattern : public DFPattern { */ class DataflowVarPatternNode : public VarPatternNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr const char* _type_key = "relax.dpl.DataflowVarPattern"; - TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarPatternNode, DFPatternNode); + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarPatternNode, VarPatternNode); }; /*! @@ -437,8 +457,12 @@ class GlobalVarPattern : public DFPattern { */ class ConstantPatternNode : public DFPatternNode { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.ConstantPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode); }; @@ -475,11 +499,15 @@ class CallPatternNode : public DFPatternNode { // Todo(relax-team): Dataflow pattern for StructInfo, and match sinfo_args - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("op", &op); - v->Visit("args", &args); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("op", &CallPatternNode::op) + .def_ro("args", &CallPatternNode::args); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.CallPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); }; @@ -498,7 +526,13 @@ class CallPattern : public DFPattern { class PrimArrPatternNode : public DFPatternNode { public: Array fields; /*!< The array to match */ - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("fields", &PrimArrPatternNode::fields); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.PrimArrPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimArrPatternNode, DFPatternNode); }; @@ -529,11 +563,15 @@ class FunctionPatternNode : public DFPatternNode { */ DFPattern body; /*!< The body of the function */ - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("params", ¶ms); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("params", &FunctionPatternNode::params) + .def_ro("body", &FunctionPatternNode::body); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.FunctionPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode); }; @@ -562,8 +600,12 @@ class TuplePatternNode : public DFPatternNode { public: tvm::Array fields; /*!< The fields of the tuple */ - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("fields", &TuplePatternNode::fields); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.TuplePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); }; @@ -586,8 +628,13 @@ class UnorderedTuplePatternNode : public DFPatternNode { public: tvm::Array fields; /*!< The fields of the tuple */ - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("fields", + &UnorderedTuplePatternNode::fields); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.UnorderedTuplePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(UnorderedTuplePatternNode, DFPatternNode); }; @@ -612,11 +659,14 @@ class TupleGetItemPatternNode : public DFPatternNode { DFPattern tuple; /*!< The tuple Expression */ int index; /*!< The index of the tuple with -1 meaning arbitrary */ - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("tuple", &tuple); - v->Visit("index", &index); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("tuple", &TupleGetItemPatternNode::tuple) + .def_ro("index", &TupleGetItemPatternNode::index); } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.TupleGetItemPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); }; @@ -640,11 +690,15 @@ class AndPatternNode : public DFPatternNode { DFPattern left; /*!< The left hand side of the conjunction */ DFPattern right; /*!< The right hand side of the conjunction */ - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("left", &left); - v->Visit("right", &right); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("left", &AndPatternNode::left) + .def_ro("right", &AndPatternNode::right); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.AndPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(AndPatternNode, DFPatternNode); }; @@ -668,11 +722,15 @@ class OrPatternNode : public DFPatternNode { DFPattern left; /*!< The left hand side of the disjunction */ DFPattern right; /*!< The right hand side of the disjunction */ - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("left", &left); - v->Visit("right", &right); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("left", &OrPatternNode::left) + .def_ro("right", &OrPatternNode::right); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.OrPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(OrPatternNode, DFPatternNode); }; @@ -695,7 +753,12 @@ class NotPatternNode : public DFPatternNode { public: DFPattern reject; /*!< The pattern to reject */ - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("reject", &reject); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("reject", &NotPatternNode::reject); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.NotPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(NotPatternNode, DFPatternNode); @@ -717,8 +780,12 @@ class NotPattern : public DFPattern { */ class WildcardPatternNode : public DFPatternNode { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.WildcardPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); }; @@ -748,11 +815,15 @@ class StructInfoPatternNode : public DFPatternNode { DFPattern pattern; /*!< The pattern to match */ StructInfo struct_info; /*!< The type to match */ - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pattern", &pattern); - v->Visit("struct_info", &struct_info); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pattern", &StructInfoPatternNode::pattern) + .def_ro("struct_info", &StructInfoPatternNode::struct_info); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.StructInfoPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(StructInfoPatternNode, DFPatternNode); }; @@ -772,11 +843,15 @@ class ShapePatternNode : public DFPatternNode { DFPattern pattern; /*!< The root pattern to match */ Array shape; /*!< The shape to match */ - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pattern", &pattern); - v->Visit("shape", &shape); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pattern", &ShapePatternNode::pattern) + .def_ro("shape", &ShapePatternNode::shape); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.ShapePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode); }; @@ -804,7 +879,12 @@ class SameShapeConstraintNode : public DFConstraintNode { std::tuple AsPrimExpr( std::function(const DFPatternNode*)> match_state) const override; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("args", &args); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("args", &SameShapeConstraintNode::args); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.SameShapeConstraint"; TVM_DECLARE_FINAL_OBJECT_INFO(SameShapeConstraintNode, DFConstraintNode); @@ -829,11 +909,15 @@ class DataTypePatternNode : public DFPatternNode { DFPattern pattern; /*!< The root pattern to match */ DataType dtype; /*!< The data type to match */ - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pattern", &pattern); - v->Visit("dtype", &dtype); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pattern", &DataTypePatternNode::pattern) + .def_ro("dtype", &DataTypePatternNode::dtype); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.DataTypePattern"; TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode); }; @@ -857,11 +941,15 @@ class AttrPatternNode : public DFPatternNode { DFPattern pattern; /*!< The root pattern to match */ DictAttrs attrs; /*!< The attributes (a map/dictionary) to match */ - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pattern", &pattern); - v->Visit("attrs", &attrs); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pattern", &AttrPatternNode::pattern) + .def_ro("attrs", &AttrPatternNode::attrs); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.AttrPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); }; @@ -887,7 +975,14 @@ class ExternFuncPatternNode : public DFPatternNode { /*! \brief The external function name */ const String& global_symbol() const { return global_symbol_; } - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("global_symbol", &global_symbol_); } + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("global_symbol", + &ExternFuncPatternNode::global_symbol_); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.dpl.ExternFuncPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncPatternNode, DFPatternNode); diff --git a/include/tvm/relax/distributed/global_info.h b/include/tvm/relax/distributed/global_info.h index 67aeccd2970a..3b9663d04ba4 100644 --- a/include/tvm/relax/distributed/global_info.h +++ b/include/tvm/relax/distributed/global_info.h @@ -45,11 +45,16 @@ class DeviceMeshNode : public GlobalInfoNode { /*! \brief Optionally use range to represent device_ids*/ Optional device_range; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("shape", &shape); - v->Visit("device_ids", &device_ids); - v->Visit("device_range", &device_range); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("shape", &DeviceMeshNode::shape) + .def_ro("device_ids", &DeviceMeshNode::device_ids) + .def_ro("device_range", &DeviceMeshNode::device_range); } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.distributed.DeviceMesh"; bool SEqualReduce(const DeviceMeshNode* other, SEqualReducer equal) const { diff --git a/include/tvm/relax/distributed/struct_info.h b/include/tvm/relax/distributed/struct_info.h index e101f20cddb4..f1dba2312206 100644 --- a/include/tvm/relax/distributed/struct_info.h +++ b/include/tvm/relax/distributed/struct_info.h @@ -44,10 +44,15 @@ class PlacementSpecNode : public Object { /*! \brief The kind of placement spec. Possible values: kSharding and kReplica. */ PlacementSpecKind kind; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("axis", &axis); - v->Visit("kind", &kind); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("axis", &PlacementSpecNode::axis) + .def_ro("kind", &PlacementSpecNode::kind); } + + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PlacementSpecNode* other, SEqualReducer equal) const { return equal(axis, other->axis) && equal(kind, other->kind); } @@ -81,7 +86,12 @@ class ShardingNode : public PlacementSpecNode { /*! \brief The dimension of tensor we shard*/ Integer sharding_dim; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("sharding_dim", &sharding_dim); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("sharding_dim", &ShardingNode::sharding_dim); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const ShardingNode* other, SEqualReducer equal) const { return equal(sharding_dim, other->sharding_dim); @@ -100,7 +110,12 @@ class PlacementNode : public Object { String ToString() const; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("dim_specs", &dim_specs); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("dim_specs", &PlacementNode::dim_specs); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const PlacementNode* other, SEqualReducer equal) const { return equal(dim_specs, other->dim_specs); @@ -144,13 +159,16 @@ class DTensorStructInfoNode : public StructInfoNode { */ Placement placement; - void VisitAttrs(AttrVisitor* v) { - v->Visit("device_mesh", &device_mesh); - v->Visit("placement", &placement); - v->Visit("tensor_sinfo", &tensor_sinfo); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("device_mesh", &DTensorStructInfoNode::device_mesh) + .def_ro("placement", &DTensorStructInfoNode::placement) + .def_ro("tensor_sinfo", &DTensorStructInfoNode::tensor_sinfo); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const DTensorStructInfoNode* other, SEqualReducer equal) const { return equal(tensor_sinfo, other->tensor_sinfo) && equal(device_mesh, other->device_mesh) && equal(placement, other->placement); diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index 16223d6bfb80..616982e78750 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -24,6 +24,7 @@ #define TVM_RELAX_EXEC_BUILDER_H_ #include +#include #include #include #include @@ -137,7 +138,12 @@ class ExecBuilderNode : public Object { */ TVM_DLL static ExecBuilder Create(); - void VisitAttrs(AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "relax.ExecBuilder"; TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object); diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 6197d1ed280f..e4049f23873c 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -54,7 +55,12 @@ class IdNode : public Object { */ String name_hint; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name_hint", &IdNode::name_hint); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const IdNode* other, SEqualReducer equal) const { return equal.FreeVarEqualImpl(this, other); @@ -160,15 +166,17 @@ class CallNode : public ExprNode { */ Array sinfo_args; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("op", &op); - v->Visit("args", &args); - v->Visit("attrs", &attrs); - v->Visit("sinfo_args", &sinfo_args); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("op", &CallNode::op) + .def_ro("args", &CallNode::args) + .def_ro("attrs", &CallNode::attrs) + .def_ro("sinfo_args", &CallNode::sinfo_args); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { // skip sinfo_args check for primitive ops. return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && @@ -221,12 +229,13 @@ class TupleNode : public ExprNode { /*! \brief the fields of the tuple */ tvm::Array fields; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("fields", &fields); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("fields", &TupleNode::fields); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { // struct info can be deterministically derived from fields. return equal(fields, other->fields); @@ -285,13 +294,15 @@ class TupleGetItemNode : public ExprNode { /*! \brief which value to get */ int index; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("tuple_value", &tuple); - v->Visit("index", &index); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("tuple_value", &TupleGetItemNode::tuple) + .def_ro("index", &TupleGetItemNode::index); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { // struct info can be deterministically tuple and index. return equal(tuple, other->tuple) && equal(index, other->index); @@ -356,12 +367,13 @@ class ShapeExprNode : public LeafExprNode { /*! The values of the shape expression. */ Array values; - void VisitAttrs(AttrVisitor* v) { - v->Visit("values", &values); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("values", &ShapeExprNode::values); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { // struct info can be deterministically derived from values. return equal(values, other->values); @@ -392,12 +404,13 @@ class VarNode : public LeafExprNode { /*! \return The name hint of the variable */ const String& name_hint() const { return vid->name_hint; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("vid", &vid); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("vid", &VarNode::vid); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); @@ -432,12 +445,13 @@ class Var : public LeafExpr { */ class DataflowVarNode : public VarNode { public: - void VisitAttrs(AttrVisitor* v) { - v->Visit("vid", &vid); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(vid, other->vid) && equal(struct_info_, other->struct_info_); @@ -483,12 +497,13 @@ class ConstantNode : public LeafExprNode { /*! \return Whether it is scalar(ndim-0 tensor) */ bool is_scalar() const { return data->ndim == 0; } - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("data", &data); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("data", &ConstantNode::data); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { // struct info can be deterministically derived from data. return equal(data, other->data) && equal(struct_info_, other->struct_info_); @@ -530,12 +545,13 @@ class PrimValueNode : public LeafExprNode { /*! \brief The prim expr representing the value */ PrimExpr value; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("value", &value); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &PrimValueNode::value); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PrimValueNode* other, SEqualReducer equal) const { // struct info can be deterministically derived from data. return equal(value, other->value); @@ -580,12 +596,13 @@ class StringImmNode : public LeafExprNode { /*! \brief The data value. */ String value; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("value", &value); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &StringImmNode::value); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { // struct info can be deterministically derived from data. return equal(value, other->value); @@ -622,12 +639,13 @@ class DataTypeImmNode : public LeafExprNode { /*! \brief The data value. */ DataType value; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("value", &value); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &DataTypeImmNode::value); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const DataTypeImmNode* other, SEqualReducer equal) const { // struct info can be deterministically derived from data. return equal(value, other->value); @@ -663,6 +681,14 @@ class BindingNode : public Object { Var var; mutable Span span; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("var", &BindingNode::var) + .def_ro("span", &BindingNode::span); + } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.expr.Binding"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -695,13 +721,16 @@ class MatchCastNode : public BindingNode { /*! \brief The struct info pattern to match to. */ StructInfo struct_info; - void VisitAttrs(AttrVisitor* v) { - v->Visit("var", &var); - v->Visit("value", &value); - v->Visit("struct_info", &struct_info); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("var", &MatchCastNode::var) + .def_ro("value", &MatchCastNode::value) + .def_ro("struct_info", &MatchCastNode::struct_info); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const; void SHashReduce(SHashReducer hash_reduce) const; @@ -728,12 +757,15 @@ class VarBindingNode : public BindingNode { /*! \brief The binding value. */ Expr value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("var", &var); - v->Visit("value", &value); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("var", &VarBindingNode::var) + .def_ro("value", &VarBindingNode::value); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const; void SHashReduce(SHashReducer hash_reduce) const; @@ -755,11 +787,13 @@ class BindingBlockNode : public Object { mutable Span span; Array bindings; - void VisitAttrs(AttrVisitor* v) { - v->Visit("span", &span); - v->Visit("bindings", &bindings); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("bindings", &BindingBlockNode::bindings); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { return equal(bindings, other->bindings); } @@ -810,13 +844,15 @@ class SeqExprNode : public ExprNode { Array blocks; Expr body; - void VisitAttrs(AttrVisitor* v) { - v->Visit("blocks", &blocks); - v->Visit("body", &body); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("blocks", &SeqExprNode::blocks) + .def_ro("body", &SeqExprNode::body); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const { return equal(blocks, other->blocks) && equal(body, other->body) && equal(struct_info_, other->struct_info_); @@ -874,14 +910,16 @@ class IfNode : public ExprNode { /*! \brief The expression evaluated when condition is false */ SeqExpr false_branch; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("cond", &cond); - v->Visit("true_branch", &true_branch); - v->Visit("false_branch", &false_branch); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("cond", &IfNode::cond) + .def_ro("true_branch", &IfNode::true_branch) + .def_ro("false_branch", &IfNode::false_branch); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(cond, other->cond) && equal(true_branch, other->true_branch) && @@ -947,16 +985,17 @@ class FunctionNode : public BaseFuncNode { /*! \brief Whether the function is annotated as pure or not. */ bool is_pure; - void VisitAttrs(AttrVisitor* v) { - v->Visit("params", ¶ms); - v->Visit("body", &body); - v->Visit("is_pure", &is_pure); - v->Visit("ret_struct_info", &ret_struct_info); - v->Visit("attrs", &attrs); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("params", &FunctionNode::params) + .def_ro("body", &FunctionNode::body) + .def_ro("ret_struct_info", &FunctionNode::ret_struct_info) + .def_ro("is_pure", &FunctionNode::is_pure); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal.DefEqual(params, other->params) && equal(body, other->body) && @@ -1055,12 +1094,13 @@ class ExternFuncNode : public BaseFuncNode { /*! \brief The name of global symbol. */ String global_symbol; - void VisitAttrs(AttrVisitor* v) { - v->Visit("global_symbol", &global_symbol); - v->Visit("struct_info_", &struct_info_); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("global_symbol", &ExternFuncNode::global_symbol); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { return equal(global_symbol, other->global_symbol) && equal(struct_info_, other->struct_info_); } diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 5c4e646351d3..689586838009 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -19,6 +19,7 @@ #ifndef TVM_RELAX_STRUCT_INFO_H_ #define TVM_RELAX_STRUCT_INFO_H_ +#include #include #include #include @@ -34,7 +35,12 @@ namespace relax { */ class ObjectStructInfoNode : public StructInfoNode { public: - void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; } @@ -66,12 +72,15 @@ class PrimStructInfoNode : public StructInfoNode { /*! \brief Underlying data type of the primitive value */ DataType dtype; - void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - v->Visit("dtype", &dtype); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("value", &PrimStructInfoNode::value) + .def_ro("dtype", &PrimStructInfoNode::dtype); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const { return equal(value, other->value) && equal(dtype, other->dtype); } @@ -116,12 +125,15 @@ class ShapeStructInfoNode : public StructInfoNode { /*! \return Whether the struct info contains unknown ndim. */ bool IsUnknownNdim() const { return ndim == kUnknownNDim; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("values", &values); - v->Visit("ndim", &ndim); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("values", &ShapeStructInfoNode::values) + .def_ro("ndim", &ShapeStructInfoNode::ndim); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const { return equal(values, other->values) && equal(ndim, other->ndim); } @@ -192,14 +204,17 @@ class TensorStructInfoNode : public StructInfoNode { return shape_sinfo->values; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("shape", &shape); - v->Visit("dtype", &dtype); - v->Visit("vdevice", &vdevice); - v->Visit("ndim", &ndim); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("shape", &TensorStructInfoNode::shape) + .def_ro("dtype", &TensorStructInfoNode::dtype) + .def_ro("vdevice", &TensorStructInfoNode::vdevice) + .def_ro("ndim", &TensorStructInfoNode::ndim); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const { return equal(shape, other->shape) && equal(ndim, other->ndim) && equal(vdevice, other->vdevice) && equal(dtype, other->dtype); @@ -255,11 +270,13 @@ class TupleStructInfoNode : public StructInfoNode { /*! \brief The struct info of tuple fields. */ Array fields; - void VisitAttrs(AttrVisitor* v) { - v->Visit("fields", &fields); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("fields", &TupleStructInfoNode::fields); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const { return equal(fields, other->fields); } @@ -331,14 +348,17 @@ class FuncStructInfoNode : public StructInfoNode { */ bool IsOpaque() const { return !params.defined(); } - void VisitAttrs(AttrVisitor* v) { - v->Visit("params", ¶ms); - v->Visit("ret", &ret); - v->Visit("derive_func", &derive_func); - v->Visit("span", &span); - v->Visit("purity", &purity); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("params", &FuncStructInfoNode::params) + .def_ro("ret", &FuncStructInfoNode::ret) + .def_ro("derive_func", &FuncStructInfoNode::derive_func) + .def_ro("purity", &FuncStructInfoNode::purity); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const { return equal.DefEqual(params, other->params) && equal(ret, other->ret) && equal(purity, other->purity) && equal(derive_func, other->derive_func); diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h index 6a00487d69a0..d39dcc2f0002 100644 --- a/include/tvm/relax/tir_pattern.h +++ b/include/tvm/relax/tir_pattern.h @@ -25,6 +25,7 @@ #ifndef TVM_RELAX_TIR_PATTERN_H_ #define TVM_RELAX_TIR_PATTERN_H_ +#include #include namespace tvm { @@ -43,11 +44,17 @@ class MatchResultNode : public Object { Array symbol_values; /*! \brief The matched buffers of input and output. */ Array matched_buffers; - void VisitAttrs(AttrVisitor* v) { - v->Visit("pattern", &pattern); - v->Visit("symbol_values", &symbol_values); - v->Visit("matched_buffers", &matched_buffers); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pattern", &MatchResultNode::pattern) + .def_ro("symbol_values", &MatchResultNode::symbol_values) + .def_ro("matched_buffers", &MatchResultNode::matched_buffers); } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.MatchResult"; TVM_DECLARE_FINAL_OBJECT_INFO(MatchResultNode, Object); }; diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 6ccd693bff02..27f226042864 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -24,11 +24,13 @@ #ifndef TVM_RELAX_TRANSFORM_H_ #define TVM_RELAX_TRANSFORM_H_ +#include #include #include #include #include #include + namespace tvm { namespace relax { namespace transform { @@ -393,14 +395,18 @@ class FusionPatternNode : public Object { */ Optional attrs_getter; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("pattern", &pattern); - v->Visit("annotation_patterns", &annotation_patterns); - v->Visit("check", &check); - v->Visit("attrs_getter", &attrs_getter); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &FusionPatternNode::name) + .def_ro("pattern", &FusionPatternNode::pattern) + .def_ro("annotation_patterns", &FusionPatternNode::annotation_patterns) + .def_ro("check", &FusionPatternNode::check) + .def_ro("attrs_getter", &FusionPatternNode::attrs_getter); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.transform.FusionPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(FusionPatternNode, Object); }; @@ -450,14 +456,18 @@ class PatternCheckContextNode : public Object { */ Map value_to_bound_var; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("matched_expr", &matched_expr); - v->Visit("annotated_expr", &annotated_expr); - v->Visit("matched_bindings", &matched_bindings); - v->Visit("var_usages", &var_usages); - v->Visit("value_to_bound_var", &value_to_bound_var); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("matched_expr", &PatternCheckContextNode::matched_expr) + .def_ro("annotated_expr", &PatternCheckContextNode::annotated_expr) + .def_ro("matched_bindings", &PatternCheckContextNode::matched_bindings) + .def_ro("var_usages", &PatternCheckContextNode::var_usages) + .def_ro("value_to_bound_var", &PatternCheckContextNode::value_to_bound_var); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.transform.PatternCheckContext"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternCheckContextNode, Object); }; diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index bd75197bfe21..753330caf1d6 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -25,6 +25,7 @@ #define TVM_RELAX_TYPE_H_ #include +#include #include #include #include @@ -43,11 +44,13 @@ class ShapeTypeNode : public TypeNode { /*! \brief size of the shape. */ int ndim; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("ndim", &ndim); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("ndim", &ShapeTypeNode::ndim); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { return equal(ndim, other->ndim); } @@ -81,12 +84,15 @@ class TensorTypeNode : public TypeNode { /*! \brief The content data type, use void to denote the dtype is unknown. */ DataType dtype; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("ndim", &ndim); - v->Visit("dtype", &dtype); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("ndim", &TensorTypeNode::ndim) + .def_ro("dtype", &TensorTypeNode::dtype); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const { return equal(ndim, other->ndim) && equal(dtype, other->dtype); } @@ -131,7 +137,12 @@ using TensorType = TensorType; class ObjectTypeNode : public TypeNode { public: - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const ObjectTypeNode* other, SEqualReducer equal) const { return true; } @@ -150,7 +161,12 @@ class ObjectType : public Type { class PackedFuncTypeNode : public TypeNode { public: - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const PackedFuncTypeNode* other, SEqualReducer equal) const { return true; } diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index b9a21b126c4f..c47385375fd7 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -19,6 +19,7 @@ #ifndef TVM_SCRIPT_PRINTER_DOC_H_ #define TVM_SCRIPT_PRINTER_DOC_H_ +#include #include #include #include @@ -62,9 +63,13 @@ class DocNode : public Object { */ mutable Array source_paths; - void VisitAttrs(AttrVisitor* v) { v->Visit("source_paths", &source_paths); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("source_paths", &DocNode::source_paths); + } static constexpr const char* _type_key = "script.printer.Doc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object); public: @@ -121,9 +126,13 @@ class ExprDocNode : public DocNode { Array kwargs_keys, // Array kwargs_values) const; - void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } static constexpr const char* _type_key = "script.printer.ExprDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(ExprDocNode, DocNode); }; @@ -163,12 +172,13 @@ class StmtDocNode : public DocNode { * */ mutable Optional comment{std::nullopt}; - void VisitAttrs(AttrVisitor* v) { - DocNode::VisitAttrs(v); - v->Visit("comment", &comment); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("comment", &StmtDocNode::comment); } static constexpr const char* _type_key = "script.printer.StmtDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(StmtDocNode, DocNode); }; @@ -196,12 +206,13 @@ class StmtBlockDocNode : public DocNode { /*! \brief The list of statements. */ Array stmts; - void VisitAttrs(AttrVisitor* v) { - DocNode::VisitAttrs(v); - v->Visit("stmts", &stmts); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("stmts", &StmtBlockDocNode::stmts); } static constexpr const char* _type_key = "script.printer.StmtBlockDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(StmtBlockDocNode, DocNode); }; @@ -237,12 +248,13 @@ class LiteralDocNode : public ExprDocNode { */ ObjectRef value; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("value", &value); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &LiteralDocNode::value); } static constexpr const char* _type_key = "script.printer.LiteralDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(LiteralDocNode, ExprDocNode); }; @@ -326,12 +338,13 @@ class IdDocNode : public ExprDocNode { /*! \brief The name of the identifier */ String name; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("name", &name); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name", &IdDocNode::name); } static constexpr const char* _type_key = "script.printer.IdDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(IdDocNode, ExprDocNode); }; @@ -363,13 +376,15 @@ class AttrAccessDocNode : public ExprDocNode { /*! \brief The attribute to be accessed */ String name; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("value", &value); - v->Visit("name", &name); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("value", &AttrAccessDocNode::value) + .def_ro("name", &AttrAccessDocNode::name); } static constexpr const char* _type_key = "script.printer.AttrAccessDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(AttrAccessDocNode, ExprDocNode); }; @@ -407,13 +422,15 @@ class IndexDocNode : public ExprDocNode { */ Array indices; // Each element is union of: Slice / ExprDoc - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("value", &value); - v->Visit("indices", &indices); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("value", &IndexDocNode::value) + .def_ro("indices", &IndexDocNode::indices); } static constexpr const char* _type_key = "script.printer.IndexDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(IndexDocNode, ExprDocNode); }; @@ -454,15 +471,17 @@ class CallDocNode : public ExprDocNode { */ Array kwargs_values; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("callee", &callee); - v->Visit("args", &args); - v->Visit("kwargs_keys", &kwargs_keys); - v->Visit("kwargs_values", &kwargs_values); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("callee", &CallDocNode::callee) + .def_ro("args", &CallDocNode::args) + .def_ro("kwargs_keys", &CallDocNode::kwargs_keys) + .def_ro("kwargs_values", &CallDocNode::kwargs_values); } static constexpr const char* _type_key = "script.printer.CallDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(CallDocNode, ExprDocNode); }; @@ -538,13 +557,15 @@ class OperationDocNode : public ExprDocNode { /*! \brief Operands of this expression */ Array operands; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("kind", &kind); - v->Visit("operands", &operands); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("kind", &OperationDocNode::kind) + .def_ro("operands", &OperationDocNode::operands); } static constexpr const char* _type_key = "script.printer.OperationDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(OperationDocNode, ExprDocNode); }; @@ -579,13 +600,15 @@ class LambdaDocNode : public ExprDocNode { /*! \brief The body of this anonymous function */ ExprDoc body{nullptr}; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("args", &args); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("args", &LambdaDocNode::args) + .def_ro("body", &LambdaDocNode::body); } static constexpr const char* _type_key = "script.printer.LambdaDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, ExprDocNode); }; @@ -615,12 +638,13 @@ class TupleDocNode : public ExprDocNode { /*! \brief Elements of tuple */ Array elements; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("elements", &elements); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("elements", &TupleDocNode::elements); } static constexpr const char* _type_key = "script.printer.TupleDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(TupleDocNode, ExprDocNode); }; @@ -653,12 +677,13 @@ class ListDocNode : public ExprDocNode { /*! \brief Elements of list */ Array elements; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("elements", &elements); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("elements", &ListDocNode::elements); } static constexpr const char* _type_key = "script.printer.ListDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ListDocNode, ExprDocNode); }; @@ -698,13 +723,15 @@ class DictDocNode : public ExprDocNode { */ Array values; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("keys", &keys); - v->Visit("values", &values); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("keys", &DictDocNode::keys) + .def_ro("values", &DictDocNode::values); } static constexpr const char* _type_key = "script.printer.DictDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(DictDocNode, ExprDocNode); }; @@ -744,14 +771,16 @@ class SliceDocNode : public DocNode { /*! \brief The step of slice */ Optional step; - void VisitAttrs(AttrVisitor* v) { - DocNode::VisitAttrs(v); - v->Visit("start", &start); - v->Visit("stop", &stop); - v->Visit("step", &step); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("start", &SliceDocNode::start) + .def_ro("stop", &SliceDocNode::stop) + .def_ro("step", &SliceDocNode::step); } static constexpr const char* _type_key = "script.printer.SliceDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(SliceDocNode, DocNode); }; @@ -790,14 +819,16 @@ class AssignDocNode : public StmtDocNode { /*! \brief The type annotation of this assignment. */ Optional annotation; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("lhs", &lhs); - v->Visit("rhs", &rhs); - v->Visit("annotation", &annotation); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("lhs", &AssignDocNode::lhs) + .def_ro("rhs", &AssignDocNode::rhs) + .def_ro("annotation", &AssignDocNode::annotation); } static constexpr const char* _type_key = "script.printer.AssignDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(AssignDocNode, StmtDocNode); }; @@ -832,14 +863,16 @@ class IfDocNode : public StmtDocNode { /*! \brief The else branch of the if-then-else statement. */ Array else_branch; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("predicate", &predicate); - v->Visit("then_branch", &then_branch); - v->Visit("else_branch", &else_branch); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("predicate", &IfDocNode::predicate) + .def_ro("then_branch", &IfDocNode::then_branch) + .def_ro("else_branch", &IfDocNode::else_branch); } static constexpr const char* _type_key = "script.printer.IfDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(IfDocNode, StmtDocNode); }; @@ -872,13 +905,15 @@ class WhileDocNode : public StmtDocNode { /*! \brief The body of the while statement. */ Array body; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("predicate", &predicate); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("predicate", &WhileDocNode::predicate) + .def_ro("body", &WhileDocNode::body); } static constexpr const char* _type_key = "script.printer.WhileDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(WhileDocNode, StmtDocNode); }; @@ -916,14 +951,16 @@ class ForDocNode : public StmtDocNode { /*! \brief The body of the for statement. */ Array body; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("lhs", &lhs); - v->Visit("rhs", &rhs); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("lhs", &ForDocNode::lhs) + .def_ro("rhs", &ForDocNode::rhs) + .def_ro("body", &ForDocNode::body); } static constexpr const char* _type_key = "script.printer.ForDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ForDocNode, StmtDocNode); }; @@ -963,14 +1000,16 @@ class ScopeDocNode : public StmtDocNode { /*! \brief The body of the scope doc. */ Array body; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("lhs", &lhs); - v->Visit("rhs", &rhs); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("lhs", &ScopeDocNode::lhs) + .def_ro("rhs", &ScopeDocNode::rhs) + .def_ro("body", &ScopeDocNode::body); } static constexpr const char* _type_key = "script.printer.ScopeDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ScopeDocNode, StmtDocNode); }; @@ -1009,12 +1048,13 @@ class ExprStmtDocNode : public StmtDocNode { /*! \brief The expression represented by this doc. */ ExprDoc expr{nullptr}; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("expr", &expr); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("expr", &ExprStmtDocNode::expr); } static constexpr const char* _type_key = "script.printer.ExprStmtDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ExprStmtDocNode, StmtDocNode); }; @@ -1045,13 +1085,15 @@ class AssertDocNode : public StmtDocNode { /*! \brief The optional error message when assertion failed. */ Optional msg{std::nullopt}; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("test", &test); - v->Visit("msg", &msg); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("test", &AssertDocNode::test) + .def_ro("msg", &AssertDocNode::msg); } static constexpr const char* _type_key = "script.printer.AssertDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(AssertDocNode, StmtDocNode); }; @@ -1081,12 +1123,13 @@ class ReturnDocNode : public StmtDocNode { /*! \brief The value to return. */ ExprDoc value{nullptr}; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("value", &value); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &ReturnDocNode::value); } static constexpr const char* _type_key = "script.printer.ReturnDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ReturnDocNode, StmtDocNode); }; @@ -1129,16 +1172,18 @@ class FunctionDocNode : public StmtDocNode { /*! \brief The body of function. */ Array body; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("name", &name); - v->Visit("args", &args); - v->Visit("decorators", &decorators); - v->Visit("return_type", &return_type); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &FunctionDocNode::name) + .def_ro("args", &FunctionDocNode::args) + .def_ro("decorators", &FunctionDocNode::decorators) + .def_ro("return_type", &FunctionDocNode::return_type) + .def_ro("body", &FunctionDocNode::body); } static constexpr const char* _type_key = "script.printer.FunctionDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(FunctionDocNode, StmtDocNode); }; @@ -1176,14 +1221,16 @@ class ClassDocNode : public StmtDocNode { /*! \brief The body of class. */ Array body; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("name", &name); - v->Visit("decorators", &decorators); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &ClassDocNode::name) + .def_ro("decorators", &ClassDocNode::decorators) + .def_ro("body", &ClassDocNode::body); } static constexpr const char* _type_key = "script.printer.ClassDoc"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ClassDocNode, StmtDocNode); }; @@ -1211,6 +1258,11 @@ class ClassDoc : public StmtDoc { */ class CommentDocNode : public StmtDocNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr const char* _type_key = "script.printer.CommentDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(CommentDocNode, StmtDocNode); }; @@ -1233,6 +1285,11 @@ class CommentDoc : public StmtDoc { */ class DocStringDocNode : public StmtDocNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr const char* _type_key = "script.printer.DocStringDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(DocStringDocNode, StmtDocNode); }; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 9211064526ed..338b65e1cf6f 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -19,6 +19,7 @@ #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_ +#include #include #include #include @@ -52,13 +53,13 @@ class FrameNode : public Object { /*! The callbacks that are going to be invoked when the frame exits */ std::vector> callbacks; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("stmts", &stmts); - // `d` is not visited - // `callbacks` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("stmts", &FrameNode::stmts); } static constexpr const char* _type_key = "script.printer.Frame"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object); public: @@ -154,17 +155,15 @@ class IRDocsifierNode : public Object { /*! \brief The IR usages for headers printing */ std::unordered_set ir_usage; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("frames", &frames); - v->Visit("dispatch_tokens", &dispatch_tokens); - // `obj2info` is not visited - // `metadata` is not visited - // `defined_names` is not visited - // `common_prefix` is not visited - // `ir_usage` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("frames", &IRDocsifierNode::frames) + .def_ro("dispatch_tokens", &IRDocsifierNode::dispatch_tokens); } static constexpr const char* _type_key = "script.printer.IRDocsifier"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(IRDocsifierNode, Object); public: diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h index ddebb3547954..9af2c8e49732 100644 --- a/include/tvm/target/tag.h +++ b/include/tvm/target/tag.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_TAG_H_ #define TVM_TARGET_TAG_H_ +#include #include #include #include @@ -40,12 +41,15 @@ class TargetTagNode : public Object { /*! \brief Config map to generate the target */ Map config; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("config", &config); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &TargetTagNode::name) + .def_ro("config", &TargetTagNode::config); } static constexpr const char* _type_key = "TargetTag"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(TargetTagNode, Object); private: diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 86e90a7ce2db..2d6b1834e228 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -24,8 +24,12 @@ #ifndef TVM_TARGET_TARGET_H_ #define TVM_TARGET_TARGET_H_ +#include #include +#include +#include #include +#include #include #include @@ -89,13 +93,15 @@ class TargetNode : public Object { */ String ToDebugString() const; - void VisitAttrs(AttrVisitor* v) { - v->Visit("kind", &kind); - v->Visit("tag", &tag); - v->Visit("keys", &keys); - v->Visit("attrs", &attrs); - v->Visit("features", &features); - v->Visit("host", &host); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("kind", &TargetNode::kind) + .def_ro("tag", &TargetNode::tag) + .def_ro("keys", &TargetNode::keys) + .def_ro("attrs", &TargetNode::attrs) + .def_ro("features", &TargetNode::features) + .def_ro("host", &TargetNode::host); } /*! @@ -171,6 +177,7 @@ class TargetNode : public Object { void SHashReduce(SHashReducer hash_reduce) const; static constexpr const char* _type_key = "Target"; + static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); diff --git a/include/tvm/target/target_info.h b/include/tvm/target/target_info.h index 946161f905f3..0c1c4abf0158 100644 --- a/include/tvm/target/target_info.h +++ b/include/tvm/target/target_info.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_TARGET_INFO_H_ #define TVM_TARGET_TARGET_INFO_H_ +#include #include #include @@ -48,14 +49,17 @@ class MemoryInfoNode : public Object { */ PrimExpr head_address; - void VisitAttrs(AttrVisitor* v) { - v->Visit("unit_bits", &unit_bits); - v->Visit("max_num_bits", &max_num_bits); - v->Visit("max_simd_bits", &max_simd_bits); - v->Visit("head_address", &head_address); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("unit_bits", &MemoryInfoNode::unit_bits) + .def_ro("max_num_bits", &MemoryInfoNode::max_num_bits) + .def_ro("max_simd_bits", &MemoryInfoNode::max_simd_bits) + .def_ro("head_address", &MemoryInfoNode::head_address); } static constexpr const char* _type_key = "MemoryInfo"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(MemoryInfoNode, Object); }; diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index f652424800dc..9875ceef3367 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -24,6 +24,8 @@ #ifndef TVM_TARGET_TARGET_KIND_H_ #define TVM_TARGET_TARGET_KIND_H_ +#include +#include #include #include @@ -75,13 +77,16 @@ class TargetKindNode : public Object { /*! \brief Function used to parse a JSON target during creation */ FTVMTargetParser target_parser; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("default_device_type", &default_device_type); - v->Visit("default_keys", &default_keys); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &TargetKindNode::name) + .def_ro("default_device_type", &TargetKindNode::default_device_type) + .def_ro("default_keys", &TargetKindNode::default_keys); } static constexpr const char* _type_key = "TargetKind"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object); private: diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index e220421a96cc..e92409df53a5 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -25,6 +25,7 @@ #define TVM_TE_OPERATION_H_ #include +#include #include #include #include @@ -82,7 +83,16 @@ class TVM_DLL OperationNode : public Object { */ virtual Array InputTensors() const = 0; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &OperationNode::name) + .def_ro("tag", &OperationNode::tag) + .def_ro("attrs", &OperationNode::attrs); + } + static constexpr const char* _type_key = "Operation"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object); }; @@ -102,15 +112,15 @@ class PlaceholderOpNode : public OperationNode { Array output_shape(size_t i) const final; Array InputTensors() const final; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("tag", &tag); - v->Visit("attrs", &attrs); - v->Visit("shape", &shape); - v->Visit("dtype", &dtype); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("shape", &PlaceholderOpNode::shape) + .def_ro("dtype", &PlaceholderOpNode::dtype); } static constexpr const char* _type_key = "PlaceholderOp"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; @@ -138,7 +148,15 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { // override functions Array output_shape(size_t idx) const final; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("axis", &BaseComputeOpNode::axis) + .def_ro("reduce_axis", &BaseComputeOpNode::reduce_axis); + } + static constexpr const char* _type_key = "BaseComputeOp"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); }; @@ -156,16 +174,13 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { DataType output_dtype(size_t i) const final; Array InputTensors() const final; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("tag", &tag); - v->Visit("attrs", &attrs); - v->Visit("axis", &axis); - v->Visit("reduce_axis", &reduce_axis); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("body", &ComputeOpNode::body); } static constexpr const char* _type_key = "ComputeOp"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); }; @@ -218,19 +233,19 @@ class ScanOpNode : public OperationNode { Array output_shape(size_t i) const final; Array InputTensors() const final; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("tag", &tag); - v->Visit("attrs", &attrs); - v->Visit("scan_axis", &scan_axis); - v->Visit("init", &init); - v->Visit("update", &update); - v->Visit("state_placeholder", &state_placeholder); - v->Visit("inputs", &inputs); - v->Visit("spatial_axis_", &spatial_axis_); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("scan_axis", &ScanOpNode::scan_axis) + .def_ro("init", &ScanOpNode::init) + .def_ro("update", &ScanOpNode::update) + .def_ro("state_placeholder", &ScanOpNode::state_placeholder) + .def_ro("inputs", &ScanOpNode::inputs) + .def_ro("spatial_axis_", &ScanOpNode::spatial_axis_); } static constexpr const char* _type_key = "ScanOp"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); }; @@ -269,17 +284,17 @@ class ExternOpNode : public OperationNode { Array output_shape(size_t i) const final; Array InputTensors() const final; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("tag", &tag); - v->Visit("attrs", &attrs); - v->Visit("inputs", &inputs); - v->Visit("input_placeholders", &input_placeholders); - v->Visit("output_placeholders", &output_placeholders); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("inputs", &ExternOpNode::inputs) + .def_ro("input_placeholders", &ExternOpNode::input_placeholders) + .def_ro("output_placeholders", &ExternOpNode::output_placeholders) + .def_ro("body", &ExternOpNode::body); } static constexpr const char* _type_key = "ExternOp"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); }; diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index b89eb1a86196..56dce360ccf1 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -25,6 +25,7 @@ #define TVM_TE_TENSOR_H_ #include +#include #include #include @@ -76,12 +77,7 @@ class TensorNode : public DataProducerNode { /*! \brief the output index from source operation */ int value_index{0}; - void VisitAttrs(AttrVisitor* v) { - v->Visit("shape", &shape); - v->Visit("dtype", &dtype); - v->Visit("op", &op); - v->Visit("value_index", &value_index); - } + static void RegisterReflection(); Array GetShape() const final { return shape; } @@ -92,6 +88,7 @@ class TensorNode : public DataProducerNode { TVM_DLL String GetNameHint() const final; static constexpr const char* _type_key = "te.Tensor"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); }; diff --git a/include/tvm/tir/block_dependence_info.h b/include/tvm/tir/block_dependence_info.h index 0e0d22c8cf51..c45f095dfe43 100644 --- a/include/tvm/tir/block_dependence_info.h +++ b/include/tvm/tir/block_dependence_info.h @@ -31,6 +31,7 @@ #ifndef TVM_TIR_BLOCK_DEPENDENCE_INFO_H_ #define TVM_TIR_BLOCK_DEPENDENCE_INFO_H_ +#include #include #include @@ -60,7 +61,12 @@ class BlockDependenceInfoNode : public Object { /*! \brief The reverse mapping from block/for-loop to their corresponding srefs */ std::unordered_map stmt2ref; - void VisitAttrs(AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "tir.BlockDependenceInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockDependenceInfoNode, Object); diff --git a/include/tvm/tir/block_scope.h b/include/tvm/tir/block_scope.h index 9a877458bdd8..0035a230c173 100644 --- a/include/tvm/tir/block_scope.h +++ b/include/tvm/tir/block_scope.h @@ -67,12 +67,13 @@ class StmtSRefNode : public Object { */ int64_t seq_index; - void VisitAttrs(AttrVisitor* v) { - // `stmt` is not visited - // `parent` is not visited - v->Visit("seq_index", &seq_index); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("seq_index", &StmtSRefNode::seq_index); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.StmtSRef"; TVM_DECLARE_FINAL_OBJECT_INFO(StmtSRefNode, Object); @@ -220,12 +221,16 @@ class DependencyNode : public Object { /*! \brief The dependency kind */ DepKind kind; - void VisitAttrs(AttrVisitor* v) { - v->Visit("src", &src); - v->Visit("dst", &dst); - v->Visit("kind", &kind); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &DependencyNode::src) + .def_ro("dst", &DependencyNode::dst) + .def_ro("kind", &DependencyNode::kind); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.Dependency"; TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object); }; @@ -267,7 +272,11 @@ class BlockScopeNode : public Object { /*! \brief The mapping from the buffer to the blocks who write it */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; - void VisitAttrs(AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register as they are not visited + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "tir.BlockScope"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, Object); diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index e0a197d41ff8..12afbc510101 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -25,6 +25,7 @@ #define TVM_TIR_BUFFER_H_ #include +#include #include #include #include @@ -111,20 +112,24 @@ class BufferNode : public Object { /*! \brief constructor */ BufferNode() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("data", &data); - v->Visit("dtype", &dtype); - v->Visit("shape", &shape); - v->Visit("strides", &strides); - v->Visit("axis_separators", &axis_separators); - v->Visit("elem_offset", &elem_offset); - v->Visit("name", &name); - v->Visit("data_alignment", &data_alignment); - v->Visit("offset_factor", &offset_factor); - v->Visit("buffer_type", &buffer_type); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("data", &BufferNode::data) + .def_ro("dtype", &BufferNode::dtype) + .def_ro("shape", &BufferNode::shape) + .def_ro("strides", &BufferNode::strides) + .def_ro("axis_separators", &BufferNode::axis_separators) + .def_ro("elem_offset", &BufferNode::elem_offset) + .def_ro("name", &BufferNode::name) + .def_ro("data_alignment", &BufferNode::data_alignment) + .def_ro("offset_factor", &BufferNode::offset_factor) + .def_ro("buffer_type", &BufferNode::buffer_type) + .def_ro("span", &BufferNode::span); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { // Use DefEqual as buffer can define variables in its semantics, // skip name as name is not important. diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 7aefef6e485b..1643ccb60bb2 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -25,6 +25,7 @@ #ifndef TVM_TIR_DATA_LAYOUT_H_ #define TVM_TIR_DATA_LAYOUT_H_ +#include #include #include @@ -107,11 +108,15 @@ class LayoutNode : public Object { */ Array axes; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("axes", &axes); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &LayoutNode::name) + .def_ro("axes", &LayoutNode::axes); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.Layout"; TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); }; @@ -310,15 +315,19 @@ class BijectiveLayoutNode : public Object { /*! \brief The destination layout */ Layout dst_layout; - void VisitAttrs(AttrVisitor* v) { - v->Visit("src_layout", &src_layout); - v->Visit("dst_layout", &dst_layout); - v->Visit("index_forward_rule", &index_forward_rule); - v->Visit("index_backward_rule", &index_backward_rule); - v->Visit("shape_forward_rule", &shape_forward_rule); - v->Visit("shape_backward_rule", &shape_backward_rule); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src_layout", &BijectiveLayoutNode::src_layout) + .def_ro("dst_layout", &BijectiveLayoutNode::dst_layout) + .def_ro("index_forward_rule", &BijectiveLayoutNode::index_forward_rule) + .def_ro("index_backward_rule", &BijectiveLayoutNode::index_backward_rule) + .def_ro("shape_forward_rule", &BijectiveLayoutNode::shape_forward_rule) + .def_ro("shape_backward_rule", &BijectiveLayoutNode::shape_backward_rule); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.BijectiveLayout"; TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object); }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 92fe19e8aa35..6bcb35d38dba 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -100,15 +100,17 @@ class PrimFuncNode : public BaseFuncNode { */ Map buffer_map; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("params", ¶ms); - v->Visit("body", &body); - v->Visit("ret_type", &ret_type); - v->Visit("buffer_map", &buffer_map); - v->Visit("attrs", &attrs); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("params", &PrimFuncNode::params) + .def_ro("body", &PrimFuncNode::body) + .def_ro("ret_type", &PrimFuncNode::ret_type) + .def_ro("buffer_map", &PrimFuncNode::buffer_map); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { // visit params and buffer_map first as they contains defs. return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && @@ -180,11 +182,15 @@ class TensorIntrinNode : public Object { /*! \brief The function of the implementation for the execution. */ PrimFunc impl; - void VisitAttrs(AttrVisitor* v) { - v->Visit("desc", &desc); - v->Visit("impl", &impl); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("desc", &TensorIntrinNode::desc) + .def_ro("impl", &TensorIntrinNode::impl); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.TensorIntrin"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); }; diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 1a5bdd8e4018..45c0c50a4b0e 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -151,12 +151,16 @@ class IndexMapNode : public Object { String ToPythonString( const std::function(const Var& var)>& f_name_map = nullptr) const; - void VisitAttrs(AttrVisitor* v) { - v->Visit("initial_indices", &initial_indices); - v->Visit("final_indices", &final_indices); - v->Visit("inverse_index_map", &inverse_index_map); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("initial_indices", &IndexMapNode::initial_indices) + .def_ro("final_indices", &IndexMapNode::final_indices) + .def_ro("inverse_index_map", &IndexMapNode::inverse_index_map); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const { return equal.DefEqual(initial_indices, other->initial_indices) && equal(final_indices, other->final_indices); diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index fe054865b738..bf96f6e0363a 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -19,6 +19,7 @@ #ifndef TVM_TIR_SCHEDULE_INSTRUCTION_H_ #define TVM_TIR_SCHEDULE_INSTRUCTION_H_ +#include #include #include @@ -111,15 +112,15 @@ class InstructionKindNode : public runtime::Object { */ FInstructionAttrsFromJSON f_attrs_from_json{nullptr}; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("_is_pure", &is_pure); - // not visited: f_apply_to_schedule - // not visited: f_as_python - // not visited: f_attrs_as_json - // not visited: f_attrs_from_json + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &InstructionKindNode::name) + .def_ro("_is_pure", &InstructionKindNode::is_pure); } + static constexpr bool _type_has_method_visit_attrs = false; + /*! \brief Checks if the instruction kind is EnterPostproc */ bool IsPostproc() const; @@ -173,13 +174,17 @@ class InstructionNode : public runtime::Object { */ Array outputs; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("kind", &kind); - v->Visit("inputs", &inputs); - v->Visit("attrs", &attrs); - v->Visit("outputs", &outputs); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("kind", &InstructionNode::kind) + .def_ro("inputs", &InstructionNode::inputs) + .def_ro("attrs", &InstructionNode::attrs) + .def_ro("outputs", &InstructionNode::outputs); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.Instruction"; TVM_DECLARE_FINAL_OBJECT_INFO(InstructionNode, runtime::Object); }; diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 1a24644b5202..6a303f6a47fd 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -50,7 +50,12 @@ enum class BufferIndexType : int32_t { /*! \brief A random variable that evaluates to a TensorIR block */ class BlockRVNode : public runtime::Object { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register as they are not visited + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.BlockRV"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object); }; @@ -71,7 +76,12 @@ class BlockRV : public runtime::ObjectRef { /*! \brief A random variable that evaluates to a TensorIR for loop */ class LoopRVNode : public runtime::Object { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register as they are not visited + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.LoopRV"; TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object); }; diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index d2d90812dd37..6a551d9923bd 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -23,6 +23,7 @@ #ifndef TVM_TIR_SCHEDULE_STATE_H_ #define TVM_TIR_SCHEDULE_STATE_H_ +#include #include #include #include @@ -118,13 +119,16 @@ class ScheduleStateNode : public Object { */ bool enable_check; - void VisitAttrs(AttrVisitor* v) { - v->Visit("mod", &mod); - // `block_info` is not visited - // `stmt2ref` is not visited - v->Visit("debug_mask", &debug_mask); - v->Visit("enable_check", &enable_check); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("mod", &ScheduleStateNode::mod) + .def_ro("debug_mask", &ScheduleStateNode::debug_mask) + .def_ro("enable_check", &ScheduleStateNode::enable_check); } + + static constexpr bool _type_has_method_visit_attrs = false; + /*! * \brief Replace the part of the AST, as being pointed to by `src_sref`, * with a specific statement `tgt_stmt`, and maintain the sref tree accordingly. diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index fca5966e198b..c5858842a095 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -62,11 +62,15 @@ class TraceNode : public runtime::Object { /*! \brief The random decisions made upon those instructions */ Map decisions; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("insts", &insts); - v->Visit("decisions", &decisions); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("insts", &TraceNode::insts) + .def_ro("decisions", &TraceNode::decisions); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.Trace"; TVM_DECLARE_FINAL_OBJECT_INFO(TraceNode, runtime::Object); diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index 640fff7af557..8591918db69d 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -297,6 +297,9 @@ def _add_class_attrs_by_reflection(int type_index, object cls): else None ) name = py_str(PyBytes_FromStringAndSize(field.name.data, field.name.size)) + if hasattr(cls, name): + # skip already defined attributes + continue setattr(cls, name, property(getter, setter, doc=doc)) for i in range(num_methods): @@ -320,6 +323,10 @@ def _add_class_attrs_by_reflection(int type_index, object cls): method_pyfunc.__doc__ = doc method_pyfunc.__name__ = name + if hasattr(cls, name): + # skip already defined attributes + continue + setattr(cls, name, method_pyfunc) return cls diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index d3b7b30628a1..f9ade53a3516 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -43,6 +43,8 @@ using tir::is_zero; using tir::make_const; using tir::make_zero; +TVM_FFI_STATIC_INIT_BLOCK({ IntervalSetNode::RegisterReflection(); }); + PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index dc40fa9d4dee..80e5b9ce874d 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -25,6 +25,7 @@ #define TVM_ARITH_INTERVAL_SET_H_ #include +#include #include #include @@ -49,11 +50,15 @@ class IntervalSetNode : public IntSetNode { PrimExpr max_value; // visitor overload. - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("min_value", &min_value); - v->Visit("max_value", &max_value); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("min_value", &IntervalSetNode::min_value) + .def_ro("max_value", &IntervalSetNode::max_value); } + static constexpr bool _type_has_method_visit_attrs = false; + /*! \return Whether the interval has upper bound. */ bool HasUpperBound() const { return !is_pos_inf(max_value) && !IsEmpty(); } /*! \return Whether the interval has lower bound. */ diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index e514ad1b1ad7..1af1bb2e39bf 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -44,6 +44,8 @@ namespace arith { #ifdef TVM_MLIR_VERSION #if TVM_MLIR_VERSION >= 150 + +TVM_FFI_STATIC_INIT_BLOCK({ PresburgerSetNode::RegisterReflection(); }); using namespace tir; static void Update(const PrimExpr& constraint, PresburgerSetNode* intset) { diff --git a/src/arith/presburger_set.h b/src/arith/presburger_set.h index d580e23a6d5a..496603b9d839 100644 --- a/src/arith/presburger_set.h +++ b/src/arith/presburger_set.h @@ -33,6 +33,7 @@ #endif #include +#include #include #include @@ -71,7 +72,12 @@ class PresburgerSetNode : public IntSetNode { PresburgerSpace space; // visitor overload. - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; /*! * \brief Do inplace union with given disjunct @@ -150,7 +156,12 @@ class PresburgerSet : public IntSet { class PresburgerSetNode : public IntSetNode { public: // dummy visitor overload. - void VisitAttrs(tvm::AttrVisitor* v) { LOG(FATAL) << "MLIR is not enabled!"; } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "arith.PresburgerSet"; TVM_DECLARE_FINAL_OBJECT_INFO(PresburgerSetNode, IntSetNode); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index c911124700fe..ae6985ddf67e 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -44,6 +44,8 @@ namespace arith { using namespace tir; +TVM_FFI_STATIC_INIT_BLOCK({ RewriteSimplifierStatsNode::RegisterReflection(); }); + // Note: When using matches_one_of or PMatchesOneOf alongside these // macros, be careful which patterns are used in the ResExpr. While // the different source expressions may be in terms of different PVar, diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 1a53bef45002..f5d5b4c47f95 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -25,6 +25,7 @@ #define TVM_ARITH_REWRITE_SIMPLIFY_H_ #include +#include #include #include @@ -53,15 +54,19 @@ struct RewriteSimplifierStatsNode : Object { int64_t max_recursive_depth{0}; int64_t num_recursive_rewrites{0}; - void VisitAttrs(AttrVisitor* v) { - v->Visit("nodes_visited", &nodes_visited); - v->Visit("constraints_entered", &constraints_entered); - v->Visit("rewrites_attempted", &rewrites_attempted); - v->Visit("rewrites_performed", &rewrites_performed); - v->Visit("max_recursive_depth", &max_recursive_depth); - v->Visit("num_recursive_rewrites", &num_recursive_rewrites); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("nodes_visited", &RewriteSimplifierStatsNode::nodes_visited) + .def_ro("constraints_entered", &RewriteSimplifierStatsNode::constraints_entered) + .def_ro("rewrites_attempted", &RewriteSimplifierStatsNode::rewrites_attempted) + .def_ro("rewrites_performed", &RewriteSimplifierStatsNode::rewrites_performed) + .def_ro("max_recursive_depth", &RewriteSimplifierStatsNode::max_recursive_depth) + .def_ro("num_recursive_rewrites", &RewriteSimplifierStatsNode::num_recursive_rewrites); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "arith.RewriteSimplifierStats"; TVM_DECLARE_FINAL_OBJECT_INFO(RewriteSimplifierStatsNode, Object); }; diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index d38beab5b4ed..32d9a623eafa 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -1419,14 +1419,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +TVM_FFI_STATIC_INIT_BLOCK({ + MSCTensorNode::RegisterReflection(); + BaseJointNode::RegisterReflection(); + MSCJointNode::RegisterReflection(); + MSCPrimNode::RegisterReflection(); + WeightJointNode::RegisterReflection(); + BaseGraphNode::RegisterReflection(); + MSCGraphNode::RegisterReflection(); + WeightGraphNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(MSCTensorNode); +TVM_REGISTER_NODE_TYPE(BaseJointNode); + TVM_REGISTER_NODE_TYPE(MSCJointNode); TVM_REGISTER_NODE_TYPE(MSCPrimNode); TVM_REGISTER_NODE_TYPE(WeightJointNode); +TVM_REGISTER_NODE_TYPE(BaseGraphNode); + TVM_REGISTER_NODE_TYPE(MSCGraphNode); TVM_REGISTER_NODE_TYPE(WeightGraphNode); diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index 1e22e96ac951..9f6ba07f0f5b 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -25,6 +25,7 @@ #define TVM_CONTRIB_MSC_CORE_IR_GRAPH_H_ #include +#include #include #include @@ -375,15 +376,19 @@ class MSCTensorNode : public Object { /*! \brief Get name of the dtype. */ const String DTypeName() const; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("alias", &alias); - v->Visit("dtype", &dtype); - v->Visit("layout", &layout); - v->Visit("shape", &shape); - v->Visit("prims", &prims); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &MSCTensorNode::name) + .def_ro("alias", &MSCTensorNode::alias) + .def_ro("dtype", &MSCTensorNode::dtype) + .def_ro("layout", &MSCTensorNode::layout) + .def_ro("shape", &MSCTensorNode::shape) + .def_ro("prims", &MSCTensorNode::prims); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const MSCTensorNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(dtype, other->dtype) && equal(shape, other->shape) && equal(layout, other->layout) && equal(prims, other->prims); @@ -486,15 +491,19 @@ class BaseJointNode : public Object { return val; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("index", &index); - v->Visit("name", &name); - v->Visit("shared_ref", &shared_ref); - v->Visit("attrs", &attrs); - v->Visit("parents", &parents); - v->Visit("children", &children); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("index", &BaseJointNode::index) + .def_ro("name", &BaseJointNode::name) + .def_ro("shared_ref", &BaseJointNode::shared_ref) + .def_ro("attrs", &BaseJointNode::attrs) + .def_ro("parents", &BaseJointNode::parents) + .def_ro("children", &BaseJointNode::children); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BaseJointNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(shared_ref, other->shared_ref) && equal(attrs, other->attrs) && equal(parents, other->parents) && @@ -570,15 +579,18 @@ class MSCJointNode : public BaseJointNode { const std::pair ProducerAndIdxOf(const String& name) const; const std::pair ProducerAndIdxOf(const MSCTensor& input) const; - void VisitAttrs(AttrVisitor* v) { - BaseJointNode::VisitAttrs(v); - v->Visit("optype", &optype); - v->Visit("scope", &scope); - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - v->Visit("weights", &weights); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("optype", &MSCJointNode::optype) + .def_ro("scope", &MSCJointNode::scope) + .def_ro("inputs", &MSCJointNode::inputs) + .def_ro("outputs", &MSCJointNode::outputs) + .def_ro("weights", &MSCJointNode::weights); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const MSCJointNode* other, SEqualReducer equal) const { return BaseJointNode::SEqualReduce(other, equal) && equal(optype, other->optype) && equal(scope, other->scope) && equal(inputs, other->inputs) && @@ -658,11 +670,13 @@ class MSCPrimNode : public BaseJointNode { /*! \brief Get child from the prim. */ const MSCPrim ChildAt(int index) const; - void VisitAttrs(AttrVisitor* v) { - BaseJointNode::VisitAttrs(v); - v->Visit("optype", &optype); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("optype", &MSCPrimNode::optype); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const MSCPrimNode* other, SEqualReducer equal) const { return BaseJointNode::SEqualReduce(other, equal) && equal(optype, other->optype); } @@ -732,13 +746,16 @@ class WeightJointNode : public BaseJointNode { /*! \brief Get child from the node. */ const WeightJoint ChildAt(int index) const; - void VisitAttrs(AttrVisitor* v) { - BaseJointNode::VisitAttrs(v); - v->Visit("weight_type", &weight_type); - v->Visit("weight", &weight); - v->Visit("friends", &friends); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("weight_type", &WeightJointNode::weight_type) + .def_ro("weight", &WeightJointNode::weight) + .def_ro("friends", &WeightJointNode::friends); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const WeightJointNode* other, SEqualReducer equal) const { return BaseJointNode::SEqualReduce(other, equal) && equal(weight_type, other->weight_type) && equal(weight, other->weight) && equal(friends, other->friends); @@ -807,12 +824,16 @@ class BaseGraphNode : public Object { /*! \brief Check if node in the graph. */ const bool HasNode(const String& name) const; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("nodes", &nodes); - v->Visit("node_names", &node_names); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &BaseGraphNode::name) + .def_ro("nodes", &BaseGraphNode::nodes) + .def_ro("node_names", &BaseGraphNode::node_names); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BaseGraphNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(nodes, other->nodes) && equal(node_names, other->node_names); @@ -906,15 +927,18 @@ class MSCGraphNode : public BaseGraphNode { /*! \brief Analysis the graph and fill info. */ void AnalysisGraph(); - void VisitAttrs(AttrVisitor* v) { - BaseGraphNode::VisitAttrs(v); - v->Visit("prims", &prims); - v->Visit("prim_names", &prim_names); - v->Visit("input_names", &input_names); - v->Visit("output_names", &output_names); - v->Visit("weight_holders", &weight_holders); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("prims", &MSCGraphNode::prims) + .def_ro("prim_names", &MSCGraphNode::prim_names) + .def_ro("input_names", &MSCGraphNode::input_names) + .def_ro("output_names", &MSCGraphNode::output_names) + .def_ro("weight_holders", &MSCGraphNode::weight_holders); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const MSCGraphNode* other, SEqualReducer equal) const { return BaseGraphNode::SEqualReduce(other, equal) && equal(prims, other->prims) && equal(prim_names, other->prim_names) && equal(input_names, other->input_names) && @@ -986,7 +1010,12 @@ class WeightGraphNode : public BaseGraphNode { /*! \brief Export graph to prototxt. */ const String ToPrototxt() const; - void VisitAttrs(AttrVisitor* v) { BaseGraphNode::VisitAttrs(v); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const WeightGraphNode* other, SEqualReducer equal) const { return BaseGraphNode::SEqualReduce(other, equal); diff --git a/src/contrib/msc/core/ir/plugin.cc b/src/contrib/msc/core/ir/plugin.cc index fc6000a20f3d..a5df533eb369 100644 --- a/src/contrib/msc/core/ir/plugin.cc +++ b/src/contrib/msc/core/ir/plugin.cc @@ -305,6 +305,13 @@ const Plugin GetPlugin(const String& name) { return PluginRegistry::Global()->Ge bool IsPlugin(const String& name) { return PluginRegistry::Global()->Registered(name); } +TVM_FFI_STATIC_INIT_BLOCK({ + PluginAttrNode::RegisterReflection(); + PluginTensorNode::RegisterReflection(); + PluginExternNode::RegisterReflection(); + PluginNode::RegisterReflection(); +}); + TVM_FFI_REGISTER_GLOBAL("msc.core.RegisterPlugin") .set_body_typed([](const String& name, const String& json_str) { PluginRegistry::Global()->Register(name, json_str); diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h index dc6f3be68dc4..6e2fc5ddce00 100644 --- a/src/contrib/msc/core/ir/plugin.h +++ b/src/contrib/msc/core/ir/plugin.h @@ -25,6 +25,7 @@ #define TVM_CONTRIB_MSC_CORE_IR_PLUGIN_H_ #include +#include #include #include @@ -268,13 +269,17 @@ class PluginAttrNode : public Object { /*! \brief Load attribute from json string. */ void FromJson(const std::string& json_str); - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("type", &type); - v->Visit("default_value", &default_value); - v->Visit("describe", &describe); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &PluginAttrNode::name) + .def_ro("type", &PluginAttrNode::type) + .def_ro("default_value", &PluginAttrNode::default_value) + .def_ro("describe", &PluginAttrNode::describe); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PluginAttrNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(type, other->type) && equal(default_value, other->default_value) && equal(describe, other->describe); @@ -345,14 +350,18 @@ class PluginTensorNode : public Object { /*! \brief Load tensor from json string. */ void FromJson(const std::string& json_str); - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("dtype", &dtype); - v->Visit("ndim", &ndim); - v->Visit("device", &device); - v->Visit("describe", &describe); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &PluginTensorNode::name) + .def_ro("dtype", &PluginTensorNode::dtype) + .def_ro("ndim", &PluginTensorNode::ndim) + .def_ro("device", &PluginTensorNode::device) + .def_ro("describe", &PluginTensorNode::describe); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PluginTensorNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(dtype, other->dtype) && equal(ndim, other->ndim) && equal(device, other->device) && equal(describe, other->describe); @@ -425,14 +434,18 @@ class PluginExternNode : public Object { /*! \brief Load extern from json string. */ void FromJson(const std::string& json_str); - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("header", &header); - v->Visit("source", &source); - v->Visit("lib", &lib); - v->Visit("describe", &describe); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &PluginExternNode::name) + .def_ro("header", &PluginExternNode::header) + .def_ro("source", &PluginExternNode::source) + .def_ro("lib", &PluginExternNode::lib) + .def_ro("describe", &PluginExternNode::describe); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PluginExternNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(header, other->header) && equal(source, other->source) && equal(lib, other->lib) && @@ -521,19 +534,23 @@ class PluginNode : public Object { /*! \brief Find input ref index for device. */ int FindDeviceRefIdx(const PluginTensor& tensor) const; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("version", &version); - v->Visit("describe", &describe); - v->Visit("attrs", &attrs); - v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); - v->Visit("buffers", &buffers); - v->Visit("externs", &externs); - v->Visit("support_dtypes", &support_dtypes); - v->Visit("options", &options); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &PluginNode::name) + .def_ro("version", &PluginNode::version) + .def_ro("describe", &PluginNode::describe) + .def_ro("attrs", &PluginNode::attrs) + .def_ro("inputs", &PluginNode::inputs) + .def_ro("outputs", &PluginNode::outputs) + .def_ro("buffers", &PluginNode::buffers) + .def_ro("externs", &PluginNode::externs) + .def_ro("support_dtypes", &PluginNode::support_dtypes) + .def_ro("options", &PluginNode::options); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const PluginNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(version, other->version) && equal(describe, other->describe) && equal(attrs, other->attrs) && diff --git a/src/contrib/msc/core/printer/msc_doc.cc b/src/contrib/msc/core/printer/msc_doc.cc index 5497b7f9fe0a..bdbe4a53caa1 100644 --- a/src/contrib/msc/core/printer/msc_doc.cc +++ b/src/contrib/msc/core/printer/msc_doc.cc @@ -86,6 +86,24 @@ LambdaDoc::LambdaDoc(IdDoc name, Array args, Array refs, Arr this->data_ = std::move(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + DeclareDocNode::RegisterReflection(); + StrictListDocNode::RegisterReflection(); + PointerDocNode::RegisterReflection(); + StructDocNode::RegisterReflection(); + ConstructorDocNode::RegisterReflection(); + SwitchDocNode::RegisterReflection(); + LambdaDocNode::RegisterReflection(); +}); + +TVM_REGISTER_NODE_TYPE(DeclareDocNode); +TVM_REGISTER_NODE_TYPE(StrictListDocNode); +TVM_REGISTER_NODE_TYPE(PointerDocNode); +TVM_REGISTER_NODE_TYPE(StructDocNode); +TVM_REGISTER_NODE_TYPE(ConstructorDocNode); +TVM_REGISTER_NODE_TYPE(SwitchDocNode); +TVM_REGISTER_NODE_TYPE(LambdaDocNode); + } // namespace msc } // namespace contrib } // namespace tvm diff --git a/src/contrib/msc/core/printer/msc_doc.h b/src/contrib/msc/core/printer/msc_doc.h index 3b83d6e22e0a..276e88f7aa4f 100644 --- a/src/contrib/msc/core/printer/msc_doc.h +++ b/src/contrib/msc/core/printer/msc_doc.h @@ -24,6 +24,7 @@ #ifndef TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ #define TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ +#include #include #include @@ -50,15 +51,18 @@ class DeclareDocNode : public ExprDocNode { /*! \brief Whether to use constructor(otherwise initializer) */ bool use_constructor{true}; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("type", &type); - v->Visit("variable", &variable); - v->Visit("init_args", &init_args); - v->Visit("use_constructor", &use_constructor); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("type", &DeclareDocNode::type) + .def_ro("variable", &DeclareDocNode::variable) + .def_ro("init_args", &DeclareDocNode::init_args) + .def_ro("use_constructor", &DeclareDocNode::use_constructor); } - static constexpr const char* _type_key = "script.printer.DeclareDoc"; + static constexpr bool _type_has_method_visit_attrs = false; + + static constexpr const char* _type_key = "msc.script.printer.DeclareDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(DeclareDocNode, ExprDocNode); }; @@ -93,13 +97,16 @@ class StrictListDocNode : public ExprDocNode { /*! \brief Whether to allow empty */ bool allow_empty{true}; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("list", &list); - v->Visit("allow_empty", &allow_empty); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("list", &StrictListDocNode::list) + .def_ro("allow_empty", &StrictListDocNode::allow_empty); } - static constexpr const char* _type_key = "script.printer.StrictListDoc"; + static constexpr bool _type_has_method_visit_attrs = false; + + static constexpr const char* _type_key = "msc.script.printer.StrictListDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(StrictListDocNode, ExprDocNode); }; @@ -129,12 +136,14 @@ class PointerDocNode : public ExprDocNode { /*! \brief The name of the identifier */ String name; - void VisitAttrs(AttrVisitor* v) { - ExprDocNode::VisitAttrs(v); - v->Visit("name", &name); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name", &PointerDocNode::name); } - static constexpr const char* _type_key = "script.printer.PointerDoc"; + static constexpr bool _type_has_method_visit_attrs = false; + + static constexpr const char* _type_key = "msc.script.printer.PointerDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(PointerDocNode, ExprDocNode); }; @@ -167,14 +176,17 @@ class StructDocNode : public StmtDocNode { /*! \brief The body of class. */ Array body; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("name", &name); - v->Visit("decorators", &decorators); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &StructDocNode::name) + .def_ro("decorators", &StructDocNode::decorators) + .def_ro("body", &StructDocNode::body); } - static constexpr const char* _type_key = "script.printer.StructDoc"; + static constexpr bool _type_has_method_visit_attrs = false; + + static constexpr const char* _type_key = "msc.script.printer.StructDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(StructDocNode, StmtDocNode); }; @@ -215,14 +227,17 @@ class ConstructorDocNode : public StmtDocNode { /*! \brief The body of function. */ Array body; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("name", &name); - v->Visit("args", &args); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &ConstructorDocNode::name) + .def_ro("args", &ConstructorDocNode::args) + .def_ro("body", &ConstructorDocNode::body); } - static constexpr const char* _type_key = "script.printer.ConstructorDoc"; + static constexpr bool _type_has_method_visit_attrs = false; + + static constexpr const char* _type_key = "msc.script.printer.ConstructorDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorDocNode, StmtDocNode); }; @@ -257,14 +272,17 @@ class SwitchDocNode : public StmtDocNode { /*! \brief The default_branch of the switch statement. */ Array default_branch; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("predicates", &predicates); - v->Visit("branchs", &branchs); - v->Visit("default_branch", &default_branch); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("predicates", &SwitchDocNode::predicates) + .def_ro("branchs", &SwitchDocNode::branchs) + .def_ro("default_branch", &SwitchDocNode::default_branch); } - static constexpr const char* _type_key = "script.printer.SwitchDoc"; + static constexpr bool _type_has_method_visit_attrs = false; + + static constexpr const char* _type_key = "msc.script.printer.SwitchDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(SwitchDocNode, StmtDocNode); }; @@ -308,15 +326,18 @@ class LambdaDocNode : public StmtDocNode { /*! \brief The body of lambda. */ Array body; - void VisitAttrs(AttrVisitor* v) { - StmtDocNode::VisitAttrs(v); - v->Visit("name", &name); - v->Visit("args", &args); - v->Visit("refs", &refs); - v->Visit("body", &body); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &LambdaDocNode::name) + .def_ro("args", &LambdaDocNode::args) + .def_ro("refs", &LambdaDocNode::refs) + .def_ro("body", &LambdaDocNode::body); } - static constexpr const char* _type_key = "script.printer.LambdaDoc"; + static constexpr bool _type_has_method_visit_attrs = false; + + static constexpr const char* _type_key = "msc.script.printer.LambdaDoc"; TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, StmtDocNode); }; diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 70197074317d..c75bf0c7361a 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -28,6 +28,12 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ + DiagnosticNode::RegisterReflection(); + DiagnosticRendererNode::RegisterReflection(); + DiagnosticContextNode::RegisterReflection(); +}); + // failed to check to argument arg0.dims[0] != 0 /* Diagnostic */ diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index ce40df21eb9a..e95b44700619 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -26,6 +26,8 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ EnvFuncNode::RegisterReflection(); }); + using ffi::Any; using ffi::Function; using ffi::PackedArgs; diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f5effa2ba522..3e2f867e0897 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -36,6 +36,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ BaseExprNode::RegisterReflection(); PrimExprNode::RegisterReflection(); RelaxExprNode::RegisterReflection(); + BaseFuncNode::RegisterReflection(); GlobalVarNode::RegisterReflection(); IntImmNode::RegisterReflection(); FloatImmNode::RegisterReflection(); diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 3df9ae00fb53..fb04b53964f1 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -24,6 +24,12 @@ #include namespace tvm { + +TVM_FFI_STATIC_INIT_BLOCK({ + VDeviceNode::RegisterReflection(); + DummyGlobalInfoNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode); TVM_FFI_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { auto n = DummyGlobalInfo(make_object()); diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index 1b47c1a89639..901202e81e67 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -30,6 +30,9 @@ #include "tvm/ir/expr.h" namespace tvm { + +TVM_FFI_STATIC_INIT_BLOCK({ GlobalVarSupplyNode::RegisterReflection(); }); + GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply, std::unordered_map name_to_var_map) { auto n = make_object(name_supply, name_to_var_map); diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index a273245c1b64..cd52e2b88680 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -32,6 +32,8 @@ namespace tvm { namespace instrument { +TVM_FFI_STATIC_INIT_BLOCK({ PassInstrumentNode::RegisterReflection(); }); + /*! * \brief Base PassInstrument implementation * \sa BasePassInstrument diff --git a/src/ir/module.cc b/src/ir/module.cc index 3166ffba9787..91db645b712a 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -35,6 +35,8 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ IRModuleNode::RegisterReflection(); }); + IRModule::IRModule(tvm::Map functions, SourceMap source_map, DictAttrs attrs, Map> global_infos) { auto n = make_object(); diff --git a/src/ir/op.cc b/src/ir/op.cc index b6d1f39526db..0442a1038b65 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -33,6 +33,8 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ OpNode::RegisterReflection(); }); + using ffi::Any; using ffi::Function; using ffi::PackedArgs; diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 482e1dfa1018..95b9e83f77a3 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -28,6 +28,14 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ + SourceNameNode::RegisterReflection(); + SpanNode::RegisterReflection(); + SequentialSpanNode::RegisterReflection(); + SourceNode::RegisterReflection(); + SourceMapObj::RegisterReflection(); +}); + ObjectPtr GetSourceNameNode(const String& name) { // always return pointer as the reference can change as map re-allocate. // or use another level of indirection by creating a unique_ptr diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 3e8a9aa6ee51..e3a6d886945a 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -348,7 +348,12 @@ class ModulePassNode : public PassNode { ModulePassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("pass_info", &ModulePassNode::pass_info); + } + + static constexpr bool _type_has_method_visit_attrs = false; /*! * \brief Run a module pass on given pass context. @@ -525,6 +530,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +TVM_FFI_STATIC_INIT_BLOCK({ + PassContextNode::RegisterReflection(); + PassInfoNode::RegisterReflection(); + SequentialNode::RegisterReflection(); + ModulePassNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_FFI_REGISTER_GLOBAL("transform.MakeModulePass") diff --git a/src/ir/type.cc b/src/ir/type.cc index 8bc48a11141f..95b65475be9e 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -25,6 +25,13 @@ #include namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ + PrimTypeNode::RegisterReflection(); + PointerTypeNode::RegisterReflection(); + TupleTypeNode::RegisterReflection(); + FuncTypeNode::RegisterReflection(); +}); + PrimType::PrimType(runtime::DataType dtype, Span span) { ObjectPtr n = make_object(); n->dtype = dtype; diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index c81543579655..518da92baf65 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -25,6 +25,8 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ PrinterConfigNode::RegisterReflection(); }); + TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { static FType inst; return inst; diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 843ffcecbe29..116a543f58e1 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -471,10 +471,13 @@ class FieldDependencyFinder : private AttrVisitor { const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); if (tinfo->extra_info != nullptr) { ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = ffi::reflection::FieldGetter(field_info)(obj); - if (auto opt_object = field_value.as()) { - ObjectRef obj = *std::move(opt_object); - this->Visit(field_info->name.data, &obj); + if (field_info->field_static_type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin || + field_info->field_static_type_index == ffi::TypeIndex::kTVMFFIAny) { + Optional index; + ParseOptionalValue(field_info->name.data, &index); + if (index.has_value()) { + jnode_->fields.push_back(*index); + } } }); } else { @@ -643,10 +646,8 @@ class JSONAttrSetter : private AttrVisitor { // dispatch between new reflection and old reflection const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); if (tinfo->extra_info != nullptr) { - ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = ffi::reflection::FieldGetter(field_info)(obj); - this->SetObjectField(obj, field_info); - }); + ffi::reflection::ForEachFieldInfo( + tinfo, [&](const TVMFFIFieldInfo* field_info) { this->SetObjectField(obj, field_info); }); } else { // TODO(tvm-team): remove this once all objects are transitioned to the new reflection reflection_->VisitAttrs(obj, this); diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index efaa7037b013..538cbab837c1 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -699,14 +699,20 @@ TVM_REGISTER_REFLECTION_VTABLE(ffi::MapObj, MapObjTrait) .set_creator([](const std::string&) -> ObjectPtr { return ffi::MapObj::Empty(); }); struct ReportNodeTrait { - static void VisitAttrs(runtime::profiling::ReportNode* report, AttrVisitor* attrs) { - attrs->Visit("calls", &report->calls); - attrs->Visit("device_metrics", &report->device_metrics); - attrs->Visit("configuration", &report->configuration); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("calls", &runtime::profiling::ReportNode::calls) + .def_ro("device_metrics", &runtime::profiling::ReportNode::device_metrics) + .def_ro("configuration", &runtime::profiling::ReportNode::configuration); } + + static constexpr const std::nullptr_t VisitAttrs = nullptr; + static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; +TVM_FFI_STATIC_INIT_BLOCK({ ReportNodeTrait::RegisterReflection(); }); TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::ReportNode, ReportNodeTrait); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -716,51 +722,86 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); struct CountNodeTrait { - static void VisitAttrs(runtime::profiling::CountNode* n, AttrVisitor* attrs) { - attrs->Visit("value", &n->value); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", + &runtime::profiling::CountNode::value); } + + static constexpr const std::nullptr_t VisitAttrs = nullptr; + static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; + +TVM_FFI_STATIC_INIT_BLOCK({ CountNodeTrait::RegisterReflection(); }); + TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::CountNode, CountNodeTrait); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->GetTypeKey() << "(" << op->value << ")"; }); + struct DurationNodeTrait { - static void VisitAttrs(runtime::profiling::DurationNode* n, AttrVisitor* attrs) { - attrs->Visit("microseconds", &n->microseconds); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro( + "microseconds", &runtime::profiling::DurationNode::microseconds); } + + static constexpr const std::nullptr_t VisitAttrs = nullptr; + static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; + +TVM_FFI_STATIC_INIT_BLOCK({ DurationNodeTrait::RegisterReflection(); }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->GetTypeKey() << "(" << op->microseconds << ")"; }); TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::DurationNode, DurationNodeTrait); + struct PercentNodeTrait { - static void VisitAttrs(runtime::profiling::PercentNode* n, AttrVisitor* attrs) { - attrs->Visit("percent", &n->percent); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro( + "percent", &runtime::profiling::PercentNode::percent); } + + static constexpr const std::nullptr_t VisitAttrs = nullptr; + static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; + +TVM_FFI_STATIC_INIT_BLOCK({ PercentNodeTrait::RegisterReflection(); }); + TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::PercentNode, PercentNodeTrait); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << op->GetTypeKey() << "(" << op->percent << ")"; }); + struct RatioNodeTrait { - static void VisitAttrs(runtime::profiling::RatioNode* n, AttrVisitor* attrs) { - attrs->Visit("ratio", &n->ratio); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("ratio", + &runtime::profiling::RatioNode::ratio); } + + static constexpr const std::nullptr_t VisitAttrs = nullptr; + static constexpr std::nullptr_t SEqualReduce = nullptr; static constexpr std::nullptr_t SHashReduce = nullptr; }; + +TVM_FFI_STATIC_INIT_BLOCK({ RatioNodeTrait::RegisterReflection(); }); + TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::RatioNode, RatioNodeTrait); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 8dee9b24f493..a3b0e8460ce7 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -21,6 +21,7 @@ * \file src/relax/backend/contrib/cutlass/codegen.cc * \brief Implementation of the CUTLASS code generator for Relax. */ +#include #include #include #include @@ -79,10 +80,15 @@ class CodegenResultNode : public Object { String code; Array headers; - void VisitAttrs(AttrVisitor* v) { - v->Visit("code", &code); - v->Visit("headers", &headers); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("code", &CodegenResultNode::code) + .def_ro("headers", &CodegenResultNode::headers); } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "contrib.cutlass.CodegenResult"; TVM_DECLARE_FINAL_OBJECT_INFO(CodegenResultNode, Object); }; @@ -99,6 +105,8 @@ class CodegenResult : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(CodegenResult, ObjectRef, CodegenResultNode); }; +TVM_FFI_STATIC_INIT_BLOCK({ CodegenResultNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(CodegenResultNode); TVM_FFI_REGISTER_GLOBAL("contrib.cutlass.CodegenResult") diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index 56b035212e3f..b13f4da6dae0 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -29,6 +29,8 @@ namespace relax { using namespace vm; +TVM_FFI_STATIC_INIT_BLOCK({ ExecBuilderNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(ExecBuilderNode); ExecBuilder ExecBuilderNode::Create() { diff --git a/src/relax/distributed/global_info.cc b/src/relax/distributed/global_info.cc index e1cc32fc82e3..a20c25102734 100644 --- a/src/relax/distributed/global_info.cc +++ b/src/relax/distributed/global_info.cc @@ -17,12 +17,15 @@ * under the License. */ +#include #include namespace tvm { namespace relax { namespace distributed { +TVM_FFI_STATIC_INIT_BLOCK({ DeviceMeshNode::RegisterReflection(); }); + DeviceMesh::DeviceMesh(ffi::Shape shape, Array device_ids) { int prod = 1; for (int i = 0; i < static_cast(shape.size()); i++) { diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/struct_info.cc index 0ff9d4d6fa09..93c5d75b5de1 100644 --- a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/struct_info.cc @@ -27,6 +27,12 @@ namespace tvm { namespace relax { namespace distributed { +TVM_FFI_STATIC_INIT_BLOCK({ + DTensorStructInfoNode::RegisterReflection(); + PlacementNode::RegisterReflection(); + PlacementSpecNode::RegisterReflection(); +}); + PlacementSpec PlacementSpec::Sharding(int axis) { ObjectPtr n = make_object(); n->axis = axis; diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index 11a0fd29a92f..e6dd082f9a24 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -22,6 +22,7 @@ * \brief Implementation of binding rewriters. */ +#include #include #include #include @@ -35,6 +36,8 @@ namespace tvm { namespace relax { +TVM_FFI_STATIC_INIT_BLOCK({ DataflowBlockRewriteNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(DataflowBlockRewriteNode); DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) { diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index 172f4d7bcb27..284d7eaf0c42 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -374,12 +375,15 @@ class PatternContextRewriterNode : public PatternMatchingRewriterNode { RewriteSpec RewriteBindings(const Array& bindings) const override; - void VisitAttrs(AttrVisitor* visitor) { - visitor->Visit("pattern", &pattern); - ffi::Function untyped_func = rewriter_func; - visitor->Visit("rewriter_func", &untyped_func); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pattern", &PatternContextRewriterNode::pattern) + .def_ro("rewriter_func", &PatternContextRewriterNode::rewriter_func); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.PatternContextRewriter"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, PatternMatchingRewriterNode); @@ -449,5 +453,7 @@ Function RewriteBindings( TVM_FFI_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); +TVM_FFI_STATIC_INIT_BLOCK({ PatternContextRewriterNode::RegisterReflection(); }); + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index c398305d938c..c105180e31a4 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -22,6 +22,7 @@ * \brief A transform to match a Relax Expr and rewrite */ +#include #include #include #include @@ -661,7 +662,6 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { Function func_replacement = [&]() { CHECK(mod->ContainGlobalVar("replacement")) << "KeyError: " - << "Expected module to contain 'replacement', " << "a Relax function defining the replacement to be matched, " << "but the module did not contain a 'replacement' function."; @@ -1075,5 +1075,12 @@ Function RewriteCall(const DFPattern& pat, TVM_FFI_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); +TVM_FFI_STATIC_INIT_BLOCK({ + PatternMatchingRewriterNode::RegisterReflection(); + ExprPatternRewriterNode::RegisterReflection(); + OrRewriterNode::RegisterReflection(); + TupleRewriterNode::RegisterReflection(); +}); + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index f322455c81e0..48332de25f3a 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -28,6 +28,33 @@ #include #include +namespace tvm { +namespace relax { + +TVM_FFI_STATIC_INIT_BLOCK({ + PatternSeqNode::RegisterReflection(); + ExprPatternNode::RegisterReflection(); + VarPatternNode::RegisterReflection(); + DataflowVarPatternNode::RegisterReflection(); + CallPatternNode::RegisterReflection(); + PrimArrPatternNode::RegisterReflection(); + FunctionPatternNode::RegisterReflection(); + TuplePatternNode::RegisterReflection(); + UnorderedTuplePatternNode::RegisterReflection(); + TupleGetItemPatternNode::RegisterReflection(); + AndPatternNode::RegisterReflection(); + OrPatternNode::RegisterReflection(); + NotPatternNode::RegisterReflection(); + WildcardPatternNode::RegisterReflection(); + StructInfoPatternNode::RegisterReflection(); + ShapePatternNode::RegisterReflection(); + SameShapeConstraintNode::RegisterReflection(); + DataTypePatternNode::RegisterReflection(); + AttrPatternNode::RegisterReflection(); + ExternFuncPatternNode::RegisterReflection(); + ConstantPatternNode::RegisterReflection(); +}); + #define RELAX_PATTERN_PRINTER_DEF(NODE_TYPE, REPR_LAMBDA) \ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) \ .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { \ @@ -35,9 +62,6 @@ REPR_LAMBDA(p, node); \ }) -namespace tvm { -namespace relax { - TVM_REGISTER_NODE_TYPE(ExternFuncPatternNode); ExternFuncPattern::ExternFuncPattern(String global_symbol) { ObjectPtr n = make_object(); diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h index d2016adbf8e7..87e855fbd937 100644 --- a/src/relax/ir/dataflow_rewriter.h +++ b/src/relax/ir/dataflow_rewriter.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAX_IR_DATAFLOW_REWRITER_H_ #define TVM_RELAX_IR_DATAFLOW_REWRITER_H_ +#include #include #include #include @@ -54,7 +55,11 @@ class PatternMatchingRewriterNode : public tvm::transform::PassNode { return RewriteSpec(); } - void VisitAttrs(AttrVisitor* visitor) {} + static void RegisterReflection() { + // PatternMatchingRewriterNode has no fields to register + } + + static constexpr bool _type_has_method_visit_attrs = false; IRModule operator()(IRModule mod, const tvm::transform::PassContext& pass_ctx) const override; tvm::transform::PassInfo Info() const override; @@ -89,12 +94,15 @@ class ExprPatternRewriterNode : public PatternMatchingRewriterNode { Optional RewriteExpr(const Expr& expr, const Map& bindings) const; - void VisitAttrs(AttrVisitor* visitor) { - visitor->Visit("pattern", &pattern); - ffi::Function untyped_func = func; - visitor->Visit("func", &untyped_func); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("pattern", &ExprPatternRewriterNode::pattern) + .def_ro("func", &ExprPatternRewriterNode::func); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.ExprPatternRewriter"; TVM_DECLARE_BASE_OBJECT_INFO(ExprPatternRewriterNode, PatternMatchingRewriterNode); }; @@ -117,11 +125,15 @@ class OrRewriterNode : public PatternMatchingRewriterNode { RewriteSpec RewriteBindings(const Array& bindings) const override; - void VisitAttrs(AttrVisitor* visitor) { - visitor->Visit("lhs", &lhs); - visitor->Visit("rhs", &rhs); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("lhs", &OrRewriterNode::lhs) + .def_ro("rhs", &OrRewriterNode::rhs); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.OrRewriter"; TVM_DECLARE_BASE_OBJECT_INFO(OrRewriterNode, PatternMatchingRewriterNode); }; @@ -142,12 +154,15 @@ class TupleRewriterNode : public PatternMatchingRewriterNode { RewriteSpec RewriteBindings(const Array& bindings) const override; - void VisitAttrs(AttrVisitor* visitor) { - visitor->Visit("patterns", &patterns); - ffi::Function untyped_func = func; - visitor->Visit("func", &untyped_func); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("patterns", &TupleRewriterNode::patterns) + .def_ro("func", &TupleRewriterNode::func); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.dpl.TupleRewriter"; TVM_DECLARE_BASE_OBJECT_INFO(TupleRewriterNode, PatternMatchingRewriterNode); diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index e75dc3c2d7ca..518ca7c0488c 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -35,6 +35,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "rxplaceholder(" << op->name << ", " << op << ")"; }); +TVM_FFI_STATIC_INIT_BLOCK({ RXPlaceholderOpNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(RXPlaceholderOpNode); te::Tensor TETensor(Expr value, Map tir_var_map, std::string name) { diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index 46207479c7ef..b5e45bde7474 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAX_IR_EMIT_TE_H_ #define TVM_RELAX_IR_EMIT_TE_H_ +#include #include #include @@ -40,15 +41,19 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { /*! \brief The relax expression. */ Expr value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("tag", &tag); - v->Visit("attrs", &attrs); - v->Visit("value", &value); - v->Visit("shape", &shape); - v->Visit("dtype", &dtype); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &RXPlaceholderOpNode::name) + .def_ro("tag", &RXPlaceholderOpNode::tag) + .def_ro("attrs", &RXPlaceholderOpNode::attrs) + .def_ro("value", &RXPlaceholderOpNode::value) + .def_ro("shape", &RXPlaceholderOpNode::shape) + .def_ro("dtype", &RXPlaceholderOpNode::dtype); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "RXPlaceholderOp"; TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode); }; diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index c9d83e92389e..4db18817e154 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -28,6 +29,28 @@ namespace relax { using tvm::ReprPrinter; +TVM_FFI_STATIC_INIT_BLOCK({ + IdNode::RegisterReflection(); + CallNode::RegisterReflection(); + TupleNode::RegisterReflection(); + TupleGetItemNode::RegisterReflection(); + ShapeExprNode::RegisterReflection(); + VarNode::RegisterReflection(); + BindingNode::RegisterReflection(); + DataflowVarNode::RegisterReflection(); + ConstantNode::RegisterReflection(); + PrimValueNode::RegisterReflection(); + StringImmNode::RegisterReflection(); + DataTypeImmNode::RegisterReflection(); + MatchCastNode::RegisterReflection(); + VarBindingNode::RegisterReflection(); + BindingBlockNode::RegisterReflection(); + SeqExprNode::RegisterReflection(); + IfNode::RegisterReflection(); + FunctionNode::RegisterReflection(); + ExternFuncNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(IdNode); Id::Id(String name_hint) { diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index dc355cef905f..4fc585b37c18 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -21,6 +21,7 @@ * \file src/relax/py_expr_functor.cc * \brief The backbone of PyExprVisitor/PyExprMutator. */ +#include #include namespace tvm { @@ -136,7 +137,12 @@ class PyExprVisitorNode : public Object, public ExprVisitor { void VisitSpan(const Span& span) PY_EXPR_VISITOR_DEFAULT(span, f_visit_span, ExprVisitor::VisitSpan(span)); - void VisitAttrs(AttrVisitor* v) {} + static void RegisterReflection() { + // PyExprVisitorNode has no fields to register + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "expr_functor.PyExprVisitor"; TVM_DECLARE_BASE_OBJECT_INFO(PyExprVisitorNode, Object); @@ -393,7 +399,13 @@ class PyExprMutatorNode : public Object, public ExprMutator { using ExprMutator::VisitWithNewScope; using ExprMutator::WithStructInfo; - void VisitAttrs(AttrVisitor* v) { v->Visit("builder_", &builder_); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("builder_", &PyExprMutatorNode::builder_); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "expr_functor.PyExprMutator"; TVM_DECLARE_BASE_OBJECT_INFO(PyExprMutatorNode, Object); @@ -689,5 +701,10 @@ TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") .set_body_typed([](PyExprMutator mutator, Id id) { return mutator->var_remap_[id]; }); +TVM_FFI_STATIC_INIT_BLOCK({ + PyExprVisitorNode::RegisterReflection(); + PyExprMutatorNode::RegisterReflection(); +}); + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 9da3de96b325..8599fc52e16b 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -22,6 +22,7 @@ * \brief Relax struct info. */ #include +#include #include #include #include @@ -29,6 +30,15 @@ namespace tvm { namespace relax { +TVM_FFI_STATIC_INIT_BLOCK({ + ObjectStructInfoNode::RegisterReflection(); + PrimStructInfoNode::RegisterReflection(); + ShapeStructInfoNode::RegisterReflection(); + TensorStructInfoNode::RegisterReflection(); + TupleStructInfoNode::RegisterReflection(); + FuncStructInfoNode::RegisterReflection(); +}); + ObjectStructInfo::ObjectStructInfo(Span span) { ObjectPtr n = make_object(); n->span = span; diff --git a/src/relax/ir/tir_pattern.cc b/src/relax/ir/tir_pattern.cc index cbe4170bb979..4138dfbcb39f 100644 --- a/src/relax/ir/tir_pattern.cc +++ b/src/relax/ir/tir_pattern.cc @@ -22,6 +22,8 @@ namespace tvm { namespace relax { +TVM_FFI_STATIC_INIT_BLOCK({ MatchResultNode::RegisterReflection(); }); + MatchResult::MatchResult(TIRPattern pattern, Array symbol_values, Array matched_buffers) { auto n = make_object(); diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index a44deba0fe94..a63671a0a2d0 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -61,7 +62,12 @@ class FunctionPassNode : public tvm::transform::PassNode { FunctionPassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("pass_info", &FunctionPassNode::pass_info); + } + + static constexpr bool _type_has_method_visit_attrs = false; /*! * \brief Run a function pass on given pass context. @@ -205,7 +211,12 @@ class DataflowBlockPassNode : public tvm::transform::PassNode { DataflowBlockPassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("pass_info", &DataflowBlockPassNode::pass_info); + } + + static constexpr bool _type_has_method_visit_attrs = false; IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; @@ -401,6 +412,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Run DataflowBlock pass: " << info->name << " at the optimization level " << info->opt_level; }); + +TVM_FFI_STATIC_INIT_BLOCK({ + FunctionPassNode::RegisterReflection(); + DataflowBlockPassNode::RegisterReflection(); +}); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 8b70bcf2c7a5..8a8aa460e80b 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -22,11 +22,19 @@ * \brief Relax type system. */ #include +#include #include namespace tvm { namespace relax { +TVM_FFI_STATIC_INIT_BLOCK({ + ShapeTypeNode::RegisterReflection(); + TensorTypeNode::RegisterReflection(); + ObjectTypeNode::RegisterReflection(); + PackedFuncTypeNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(ShapeTypeNode); ShapeType::ShapeType(int ndim, Span span) { diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 51ab6bb23068..723118411bbf 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -23,6 +23,7 @@ * into in-place versions. */ +#include #include #include #include @@ -524,15 +525,21 @@ class InplaceOpportunityNode : public Object { Integer binding_idx; Array arg_idxs; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("binding_idx", &binding_idx); - v->Visit("arg_idxs", &arg_idxs); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("binding_idx", &InplaceOpportunityNode::binding_idx) + .def_ro("arg_idxs", &InplaceOpportunityNode::arg_idxs); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "relax.transform.InplaceOpportunity"; TVM_DECLARE_BASE_OBJECT_INFO(InplaceOpportunityNode, Object); }; +TVM_FFI_STATIC_INIT_BLOCK({ InplaceOpportunityNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(InplaceOpportunityNode); class InplaceOpportunity : public ObjectRef { diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index f9ffcd930283..801dea14856d 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -27,6 +27,7 @@ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. */ +#include #include #include #include @@ -47,6 +48,11 @@ namespace tvm { namespace relax { +TVM_FFI_STATIC_INIT_BLOCK({ + transform::FusionPatternNode::RegisterReflection(); + transform::PatternCheckContextNode::RegisterReflection(); +}); + /* Note on Fusing algorithm: diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc index aca048820996..4bba41d6571b 100644 --- a/src/relax/transform/infer_layout_utils.cc +++ b/src/relax/transform/infer_layout_utils.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ - #include "infer_layout_utils.h" +#include + #include "utils.h" namespace tvm { @@ -155,5 +156,10 @@ LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim) { } } +TVM_FFI_STATIC_INIT_BLOCK({ + LayoutDecisionNode::RegisterReflection(); + InferLayoutOutputNode::RegisterReflection(); +}); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index d8666cc431da..f8675b70e8f4 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -27,6 +27,7 @@ #ifndef TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_ #define TVM_RELAX_TRANSFORM_INFER_LAYOUT_UTILS_H_ +#include #include #include #include @@ -61,7 +62,14 @@ class LayoutDecisionNode : public Object { /*! \brief Whether the dim of tensor is unknown. */ bool is_unknown_dim = false; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("layout", &layout); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("layout", &LayoutDecisionNode::layout) + .def_ro("is_unknown_dim", &LayoutDecisionNode::is_unknown_dim); + } + + static constexpr bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(LayoutDecisionNode, Object); @@ -104,12 +112,17 @@ class InferLayoutOutputNode : public Object { Attrs new_attrs; Map new_args; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("input_layouts", &input_layouts); - v->Visit("output_layouts", &output_layouts); - v->Visit("new_attrs", &new_attrs); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("input_layouts", &InferLayoutOutputNode::input_layouts) + .def_ro("output_layouts", &InferLayoutOutputNode::output_layouts) + .def_ro("new_attrs", &InferLayoutOutputNode::new_attrs) + .def_ro("new_args", &InferLayoutOutputNode::new_args); } + static constexpr bool _type_has_method_visit_attrs = false; + TVM_DECLARE_BASE_OBJECT_INFO(InferLayoutOutputNode, Object); static constexpr const char* _type_key = "relax.transform.InferLayoutOutput"; diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 8f1fd77d782d..1665b9d88b12 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -25,6 +25,36 @@ namespace tvm { namespace script { namespace printer { +TVM_FFI_STATIC_INIT_BLOCK({ + DocNode::RegisterReflection(); + ExprDocNode::RegisterReflection(); + StmtDocNode::RegisterReflection(); + StmtBlockDocNode::RegisterReflection(); + LiteralDocNode::RegisterReflection(); + IdDocNode::RegisterReflection(); + AttrAccessDocNode::RegisterReflection(); + IndexDocNode::RegisterReflection(); + CallDocNode::RegisterReflection(); + OperationDocNode::RegisterReflection(); + LambdaDocNode::RegisterReflection(); + TupleDocNode::RegisterReflection(); + ListDocNode::RegisterReflection(); + DictDocNode::RegisterReflection(); + SliceDocNode::RegisterReflection(); + AssignDocNode::RegisterReflection(); + IfDocNode::RegisterReflection(); + WhileDocNode::RegisterReflection(); + ForDocNode::RegisterReflection(); + ScopeDocNode::RegisterReflection(); + ExprStmtDocNode::RegisterReflection(); + AssertDocNode::RegisterReflection(); + ReturnDocNode::RegisterReflection(); + FunctionDocNode::RegisterReflection(); + ClassDocNode::RegisterReflection(); + CommentDocNode::RegisterReflection(); + DocStringDocNode::RegisterReflection(); +}); + ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } ExprDoc ExprDocNode::operator[](Array indices) const { diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index c8f029d225a8..b6f7a1fc3546 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -24,6 +24,8 @@ namespace tvm { namespace script { namespace printer { +TVM_FFI_STATIC_INIT_BLOCK({ IRFrameNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(IRFrameNode); struct SortableFunction { diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index 88e0113e2840..4d5f711bad9c 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -19,6 +19,7 @@ #ifndef TVM_SCRIPT_PRINTER_IR_UTILS_H_ #define TVM_SCRIPT_PRINTER_IR_UTILS_H_ +#include #include #include #include @@ -37,11 +38,14 @@ namespace printer { class IRFrameNode : public FrameNode { public: Map>* global_infos = nullptr; - void VisitAttrs(AttrVisitor* v) { - FrameNode::VisitAttrs(v); - // `global_infos` is not visited + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + // global infos is not exposed } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.printer.IRFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(IRFrameNode, FrameNode); }; diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 6b58d1e03a2a..173eb58a306b 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -30,6 +30,11 @@ namespace tvm { namespace script { namespace printer { +TVM_FFI_STATIC_INIT_BLOCK({ + FrameNode::RegisterReflection(); + IRDocsifierNode::RegisterReflection(); +}); + IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { if (auto it = obj2info.find(obj); it != obj2info.end()) { // TVM's IR dialects do not allow multiple definitions of the same diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 99e30ab520a5..cba6f88ff88b 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -37,6 +37,8 @@ bool AtTopLevelFunction(const IRDocsifier& d) { return d->frames.size() == 3; } +TVM_FFI_STATIC_INIT_BLOCK({ RelaxFrameNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(RelaxFrameNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index e28fd9c8036b..f495bcba887d 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -19,6 +19,7 @@ #ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ #define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_ +#include #include #include #include @@ -43,12 +44,15 @@ class RelaxFrameNode : public FrameNode { bool module_alias_printed = false; std::unordered_set* func_vars = nullptr; - void VisitAttrs(AttrVisitor* v) { - FrameNode::VisitAttrs(v); - v->Visit("is_global_func", &is_func); - // `func_var_to_define` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("is_func", &RelaxFrameNode::is_func) + .def_ro("module_alias_printed", &RelaxFrameNode::module_alias_printed); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.printer.RelaxFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(RelaxFrameNode, FrameNode); }; diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index bba1686f920c..b98e4545b4ce 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -24,6 +24,8 @@ namespace tvm { namespace script { namespace printer { +TVM_FFI_STATIC_INIT_BLOCK({ TIRFrameNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(TIRFrameNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index d1bc56d13960..6b37ca955078 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -19,6 +19,7 @@ #ifndef TVM_SCRIPT_PRINTER_TIR_UTILS_H_ #define TVM_SCRIPT_PRINTER_TIR_UTILS_H_ +#include #include #include #include @@ -48,12 +49,15 @@ class TIRFrameNode : public FrameNode { /*! \brief Whether or not the frame allows concise scoping */ bool allow_concise_scoping{false}; - void VisitAttrs(AttrVisitor* v) { - FrameNode::VisitAttrs(v); - v->Visit("tir", &tir); - v->Visit("allow_concise_scoping", &allow_concise_scoping); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("tir", &TIRFrameNode::tir) + .def_ro("allow_concise_scoping", &TIRFrameNode::allow_concise_scoping); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.printer.TIRFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(TIRFrameNode, FrameNode); }; diff --git a/src/target/tag.cc b/src/target/tag.cc index 0df0d8d2c7af..04f0a146034f 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -31,6 +31,8 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ TargetTagNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(TargetTagNode); TVM_FFI_REGISTER_GLOBAL("target.TargetTagListTags").set_body_typed(TargetTag::ListTags); diff --git a/src/target/target.cc b/src/target/target.cc index d9e3f9b51ee7..c73918b9d125 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -42,6 +42,8 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ TargetNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(TargetNode); class TargetInternal { diff --git a/src/target/target_info.cc b/src/target/target_info.cc index 6e673905d3c2..4d6624e0a4b5 100644 --- a/src/target/target_info.cc +++ b/src/target/target_info.cc @@ -26,6 +26,8 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ MemoryInfoNode::RegisterReflection(); }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index cdec2ede0643..ae35c7d97f12 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -35,6 +35,8 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ TargetKindNode::RegisterReflection(); }); + // helper to get internal dev function in objectref. struct TargetKind2ObjectPtr : public ObjectRef { static ObjectPtr Get(const TargetKind& kind) { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 294b34bf5d2e..c626fe6aa17b 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -38,6 +38,12 @@ namespace tvm { namespace te { using namespace tir; +TVM_FFI_STATIC_INIT_BLOCK({ + OperationNode::RegisterReflection(); + BaseComputeOpNode::RegisterReflection(); + ComputeOpNode::RegisterReflection(); +}); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 9f8531998e88..c7283d847691 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -29,6 +29,9 @@ namespace tvm { namespace te { using namespace tir; + +TVM_FFI_STATIC_INIT_BLOCK({ ExternOpNode::RegisterReflection(); }); + // ExternOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index cce70420c0bd..2e826c836ed6 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -28,6 +28,8 @@ namespace tvm { namespace te { +TVM_FFI_STATIC_INIT_BLOCK({ PlaceholderOpNode::RegisterReflection(); }); + // PlaceholderOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index f4860cf71ef7..c4d56a7f15c2 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -29,6 +29,8 @@ namespace tvm { namespace te { using namespace tir; +TVM_FFI_STATIC_INIT_BLOCK({ ScanOpNode::RegisterReflection(); }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/te/tensor.cc b/src/te/tensor.cc index a23f4b494ece..cb3cb593d751 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -27,6 +27,17 @@ namespace tvm { namespace te { +void TensorNode::RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("shape", &TensorNode::shape) + .def_ro("dtype", &TensorNode::dtype) + .def_ro("op", &TensorNode::op) + .def_ro("value_index", &TensorNode::value_index); +} + +TVM_FFI_STATIC_INIT_BLOCK({ TensorNode::RegisterReflection(); }); + IterVar thread_axis(Range dom, std::string tag) { return IterVar(dom, Var(tag, dom.defined() ? dom->extent.dtype() : DataType::Int(32)), kThreadIndex, tag); diff --git a/src/tir/ir/block_dependence_info.cc b/src/tir/ir/block_dependence_info.cc index dc1e3c48c924..9267ecec75a3 100644 --- a/src/tir/ir/block_dependence_info.cc +++ b/src/tir/ir/block_dependence_info.cc @@ -23,6 +23,8 @@ namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ BlockDependenceInfoNode::RegisterReflection(); }); + /** * @brief A helper class to collect and build Block Dependences using BlockScope class */ diff --git a/src/tir/ir/block_scope.cc b/src/tir/ir/block_scope.cc index 381fae73a475..70d35aa9e259 100644 --- a/src/tir/ir/block_scope.cc +++ b/src/tir/ir/block_scope.cc @@ -16,12 +16,19 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ + StmtSRefNode::RegisterReflection(); + DependencyNode::RegisterReflection(); + BlockScopeNode::RegisterReflection(); +}); + /******** Utility functions ********/ template diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index bce9c2c4e1a8..8fcd909bb2fc 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -37,6 +37,8 @@ namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ BufferNode::RegisterReflection(); }); + using IndexMod = tir::FloorModNode; using IndexDiv = tir::FloorDivNode; diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 96f87344cbea..fbb1901f087e 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -34,6 +34,11 @@ using tir::IterVar; using tir::IterVarNode; using tir::Var; +TVM_FFI_STATIC_INIT_BLOCK({ + LayoutNode::RegisterReflection(); + BijectiveLayoutNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(LayoutNode); TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 9eb9bc26e343..3996435e8e84 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -22,6 +22,7 @@ * \brief The function data structure. */ #include +#include #include #include #include @@ -29,6 +30,12 @@ namespace tvm { namespace tir { + +TVM_FFI_STATIC_INIT_BLOCK({ + PrimFuncNode::RegisterReflection(); + TensorIntrinNode::RegisterReflection(); +}); + namespace { relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { Array params; diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 7297b62bf36d..1596be567fc9 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,8 @@ namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ IndexMapNode::RegisterReflection(); }); + IndexMap::IndexMap(Array initial_indices, Array final_indices, Optional inverse_index_map) { auto n = make_object(); diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 6152c99eaf7d..9e65365719b2 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -23,6 +23,7 @@ * StmtExprVisitor/StmtExprMutator. */ +#include #include #include @@ -213,7 +214,12 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { vtable(stmt, this); } - void VisitAttrs(AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register as they are not visited + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.PyStmtExprVisitor"; TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprVisitorNode, Object); @@ -572,7 +578,13 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { static FStmtType vtable = InitStmtVTable(); vtable(stmt, this); } - void VisitAttrs(AttrVisitor* v) {} + + static void RegisterReflection() { + // No fields to register as they are not visited + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.PyStmtExprMutator"; TVM_DECLARE_BASE_OBJECT_INFO(PyStmtExprMutatorNode, Object); @@ -628,7 +640,6 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { PY_EXPR_MUTATOR_DISPATCH(FloatImmNode, f_visit_float_imm); PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm); - private: private: static FExprType InitExprVTable() { FExprType vtable; @@ -813,6 +824,11 @@ class PyStmtExprMutator : public ObjectRef { // TVM Register // ================================================ +TVM_FFI_STATIC_INIT_BLOCK({ + PyStmtExprVisitorNode::RegisterReflection(); + PyStmtExprMutatorNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(PyStmtExprVisitorNode); TVM_REGISTER_NODE_TYPE(PyStmtExprMutatorNode); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 6a5e1191d219..05684c1bd8ce 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -22,6 +22,7 @@ * \brief TIR specific transformation passes. */ #include +#include #include #include #include @@ -62,7 +63,12 @@ class PrimFuncPassNode : public PassNode { /*! \brief The pass function called on each. */ std::function pass_func; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("pass_info", &PrimFuncPassNode::pass_info); + } + + static constexpr bool _type_has_method_visit_attrs = false; /*! * \brief Run a function pass on given pass context. @@ -142,6 +148,8 @@ Pass CreatePrimFuncPass(std::function return PrimFuncPass(std::move(pass_func), pass_info); } +TVM_FFI_STATIC_INIT_BLOCK({ PrimFuncPassNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); TVM_FFI_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index ad890ecb404e..277b89628331 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -741,12 +741,16 @@ class TensorizeInfoNode : public Object { */ Optional> block_iter_paddings; - void VisitAttrs(AttrVisitor* v) { - v->Visit("loop_map", &loop_map); - v->Visit("desc_loop_indexer", &desc_loop_indexer); - v->Visit("block_iter_paddings", &block_iter_paddings); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("loop_map", &TensorizeInfoNode::loop_map) + .def_ro("desc_loop_indexer", &TensorizeInfoNode::desc_loop_indexer) + .def_ro("block_iter_paddings", &TensorizeInfoNode::block_iter_paddings); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.schedule.TensorizeInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); }; @@ -785,14 +789,18 @@ class AutoTensorizeMappingInfoNode : public Object { /*! \brief Block iters on RHS */ Array rhs_iters; - void VisitAttrs(AttrVisitor* v) { - v->Visit("mappings", &mappings); - v->Visit("lhs_buffer_map", &lhs_buffer_map); - v->Visit("rhs_buffer_indices", &rhs_buffer_indices); - v->Visit("lhs_iters", &lhs_iters); - v->Visit("rhs_iters", &rhs_iters); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("mappings", &AutoTensorizeMappingInfoNode::mappings) + .def_ro("lhs_buffer_map", &AutoTensorizeMappingInfoNode::lhs_buffer_map) + .def_ro("rhs_buffer_indices", &AutoTensorizeMappingInfoNode::rhs_buffer_indices) + .def_ro("lhs_iters", &AutoTensorizeMappingInfoNode::lhs_iters) + .def_ro("rhs_iters", &AutoTensorizeMappingInfoNode::rhs_iters); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "tir.schedule.AutoTensorizeMappingInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object); }; diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 99f4050a84e5..9d23661bace3 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,12 +16,19 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../ir_comparator.h" #include "../utils.h" namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ + TensorizeInfoNode::RegisterReflection(); + AutoTensorizeMappingInfoNode::RegisterReflection(); +}); + /******** IR Module ********/ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index b00d2069ed17..70cb57dc6423 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -50,15 +50,12 @@ class ConcreteScheduleNode : public ScheduleNode { support::LinearCongruentialEngine::TRandState rand_state_; public: - void VisitAttrs(tvm::AttrVisitor* v) { - // `state_` is not visited - // `func_working_on_` is not visited - // `error_render_level_` is not visited - // `symbol_table_` is not visited - // `analyzer_` is not visited - // `rand_state_` is not visited + static void RegisterReflection() { + // No fields to register as they are not visited } + static constexpr bool _type_has_method_visit_attrs = false; + virtual ~ConcreteScheduleNode() = default; public: diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index 7fd43c9242f0..7851c697a144 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -21,6 +21,11 @@ namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ + InstructionKindNode::RegisterReflection(); + InstructionNode::RegisterReflection(); +}); + bool InstructionKindNode::IsPostproc() const { static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); return this == inst_enter_postproc.get(); diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 8dc1dcf8dbb2..7ac47e136983 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -16,10 +16,17 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ + BlockRVNode::RegisterReflection(); + LoopRVNode::RegisterReflection(); +}); + /**************** Constructor ****************/ BlockRV::BlockRV() { this->data_ = make_object(); } diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index f2c4b56121c9..7bda23f1df70 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -22,6 +22,8 @@ namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ ScheduleStateNode::RegisterReflection(); }); + template using SMap = std::unordered_map; diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 1992f5ae8a69..574ee3aab625 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -16,11 +16,15 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ TraceNode::RegisterReflection(); }); + /**************** Constructors ****************/ Trace::Trace() { data_ = make_object(); } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 777f31a57bea..50586d70e6de 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -31,14 +31,12 @@ class TracedScheduleNode : public ConcreteScheduleNode { Trace trace_; public: - void VisitAttrs(tvm::AttrVisitor* v) { - // `state_` is not visited - // `error_render_level_` is not visited - // `symbol_table_` is not visited - // `analyzer_` is not visitied - // `trace_` is not visited + static void RegisterReflection() { + // No fields to register as they are not visited } + static constexpr bool _type_has_method_visit_attrs = false; + ~TracedScheduleNode() = default; public: