diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h index d53a4817ad0b..04d96857cb9a 100644 --- a/ffi/include/tvm/ffi/reflection/reflection.h +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -404,6 +404,8 @@ inline Function GetMethod(std::string_view type_key, const char* method_name) { */ template inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) { + using ResultType = decltype(callback(type_info->fields)); + static_assert(std::is_same_v, "Callback must return void"); // iterate through acenstors in parent to child order // skip the first one since it is always the root object for (int i = 1; i < type_info->type_depth; ++i) { @@ -417,6 +419,34 @@ inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) { } } +/*! + * \brief Visit each field info of the type info and run callback which returns bool for early stop. + * + * \tparam Callback The callback function type, which returns bool for early stop. + * + * \param type_info The type info. + * \param callback_with_early_stop The callback function. + * \return true if any of early stop is triggered. + * + * \note This function calls both the child and parent type info and can be used for searching. + */ +template +inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info, + Callback callback_with_early_stop) { + // iterate through acenstors in parent to child order + // skip the first one since it is always the root object + for (int i = 1; i < type_info->type_depth; ++i) { + const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i]; + for (int j = 0; j < parent_info->num_fields; ++j) { + if (callback_with_early_stop(parent_info->fields + j)) return true; + } + } + for (int i = 0; i < type_info->num_fields; ++i) { + if (callback_with_early_stop(type_info->fields + i)) return true; + } + return false; +} + } // namespace reflection } // namespace ffi } // namespace tvm diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 6378d6f74ac2..a40982251253 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -47,6 +47,7 @@ #include #include #include +#include #include #include #include @@ -970,5 +971,65 @@ inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(* } } +/*! + * \brief Adapter for AttrsNode with the new reflection API. + * + * We will phaseout the old AttrsNode in future in favor of the new reflection API. + * This adapter allows us to gradually migrate to the new reflection API. + * + * \tparam DerivedType The final attribute type. + */ +template +class AttrsNodeReflAdapter : public BaseAttrsNode { + public: + void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final { + LOG(FATAL) << "`" << DerivedType::_type_key << "` uses new reflection mechanism for init"; + } + void VisitNonDefaultAttrs(AttrVisitor* v) final { + LOG(FATAL) << "`" << DerivedType::_type_key + << "` uses new reflection mechanism for visit non default attrs"; + } + void VisitAttrs(AttrVisitor* v) final { + LOG(FATAL) << "`" << DerivedType::_type_key + << "` uses new reflection mechanism for visit attrs"; + } + + bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex()); + bool success = true; + ffi::reflection::ForEachFieldInfoWithEarlyStop( + type_info, [&](const TVMFFIFieldInfo* field_info) { + ffi::reflection::FieldGetter field_getter(field_info); + ffi::Any field_value = field_getter(self()); + ffi::Any other_field_value = field_getter(other); + if (!equal.AnyEqual(field_value, other_field_value)) { + success = false; + return true; + } + return false; + }); + return success; + } + + void SHashReduce(SHashReducer hash_reducer) const { + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex()); + ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + ffi::reflection::FieldGetter field_getter(field_info); + ffi::Any field_value = field_getter(self()); + hash_reducer(field_value); + }); + } + + Array ListFieldInfo() const final { + // use the new reflection to list field info + return Array(); + } + + private: + DerivedType* self() const { + return const_cast(static_cast(this)); + } +}; + } // namespace tvm #endif // TVM_IR_ATTRS_H_ diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h index de043f92be82..ffc8301e9f28 100644 --- a/include/tvm/relax/attrs/ccl.h +++ b/include/tvm/relax/attrs/ccl.h @@ -24,54 +24,70 @@ #ifndef TVM_RELAX_ATTRS_CCL_H_ #define TVM_RELAX_ATTRS_CCL_H_ +#include #include namespace tvm { namespace relax { /*! \brief Attributes used in allreduce operators */ -struct AllReduceAttrs : public tvm::AttrsNode { +struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter { String op_type; bool in_group; - TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") { - TVM_ATTR_FIELD(op_type).describe( - "The type of reduction operation to be applied to the input data. Now only sum is " - "supported."); - TVM_ATTR_FIELD(in_group).describe( - "Whether the reduction operation performs in group or globally or in group as default."); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("op_type", &AllReduceAttrs::op_type, + "The type of reduction operation to be applied to the input data. Now only sum is " + "supported.") + .def_ro("in_group", &AllReduceAttrs::in_group, + "Whether the reduction operation performs in group or globally or in group as " + "default."); } + + static constexpr const char* _type_key = "relax.attrs.AllReduceAttrs"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllReduceAttrs, BaseAttrsNode); }; // struct AllReduceAttrs /*! \brief Attributes used in allgather operators */ -struct AllGatherAttrs : public tvm::AttrsNode { +struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter { int num_workers; bool in_group; - TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") { - TVM_ATTR_FIELD(num_workers) - .describe( - "The number of workers, also the number of parts the given buffer should be chunked " - "into."); - TVM_ATTR_FIELD(in_group).describe( - "Whether the allgather operation performs in group or globally or in group as default."); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("num_workers", &AllGatherAttrs::num_workers, + "The number of workers, also the number of parts the given buffer should be " + "chunked into.") + .def_ro("in_group", &AllGatherAttrs::in_group, + "Whether the allgather operation performs in group or globally or in group as " + "default."); } + + static constexpr const char* _type_key = "relax.attrs.AllGatherAttrs"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllGatherAttrs, BaseAttrsNode); }; // struct AllGatherAttrs /*! \brief Attributes used in scatter operators */ -struct ScatterCollectiveAttrs : public tvm::AttrsNode { +struct ScatterCollectiveAttrs : public tvm::AttrsNodeReflAdapter { int num_workers; int axis; - TVM_DECLARE_ATTRS(ScatterCollectiveAttrs, "relax.attrs.ScatterCollectiveAttrs") { - TVM_ATTR_FIELD(num_workers) - .describe( - "The number of workers, also the number of parts the given buffer should be chunked " - "into."); - TVM_ATTR_FIELD(axis).describe( - "The axis of the tensor to be scattered. The tensor will be chunked along " - "this axis."); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("num_workers", &ScatterCollectiveAttrs::num_workers, + "The number of workers, also the number of parts the given buffer should be " + "chunked into.") + .def_ro("axis", &ScatterCollectiveAttrs::axis, + "The axis of the tensor to be scattered. The tensor will be chunked along " + "this axis."); } + + static constexpr const char* _type_key = "relax.attrs.ScatterCollectiveAttrs"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScatterCollectiveAttrs, BaseAttrsNode); }; // struct ScatterCollectiveAttrs } // namespace relax diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 2550f5652fc7..69903b26a9f0 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -50,7 +50,7 @@ void FuncAttrGetter::VisitExpr_(const CallNode* op) { if (op->attrs.defined()) { Map attrs; AttrGetter getter(&attrs); - const_cast(op->attrs.get())->VisitAttrs(&getter); + getter(op->attrs); for (const auto& pair : attrs) { if (attrs_.count(pair.first)) { int cnt = 1; @@ -350,7 +350,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin attrs = FuncAttrGetter().GetAttrs(call_node->op); } else if (call_node->attrs.defined()) { AttrGetter getter(&attrs); - const_cast(call_node->attrs.get())->VisitAttrs(&getter); + getter(call_node->attrs); } } else if (const auto* const_node = expr.as()) { if (const_node->is_scalar()) { diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 4eac04349728..b2689ee7b7fd 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -25,6 +25,7 @@ #define TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_ #include +#include #include #include #include @@ -106,7 +107,7 @@ struct MSCRBuildConfig { } }; -class AttrGetter : public AttrVisitor { +class AttrGetter : private AttrVisitor { public: /*! * \brief Get the attributes as Map @@ -114,6 +115,57 @@ class AttrGetter : public AttrVisitor { */ explicit AttrGetter(Map* attrs) : attrs_(attrs) {} + void operator()(const Attrs& attrs) { + // dispatch between new reflection and old reflection + const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index()); + if (attrs_tinfo->extra_info != nullptr) { + tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs); + this->VisitAny(String(field_info->name), field_value); + }); + } else { + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + const_cast(attrs.get())->VisitAttrs(this); + } + } + + private: + void VisitAny(String key, Any value) { + switch (value.type_index()) { + case kTVMFFINone: { + attrs_->Set(key, ""); + break; + } + case kTVMFFIBool: { + attrs_->Set(key, std::to_string(value.cast())); + break; + } + case kTVMFFIInt: { + attrs_->Set(key, std::to_string(value.cast())); + break; + } + case kTVMFFIFloat: { + attrs_->Set(key, std::to_string(value.cast())); + break; + } + case kTVMFFIDataType: { + attrs_->Set(key, runtime::DLDataTypeToString(value.cast())); + } + case kTVMFFIStr: { + attrs_->Set(key, value.cast()); + break; + } + default: { + if (value.type_index() >= kTVMFFIStaticObjectBegin) { + attrs_->Set(key, StringUtils::ToString(value.cast())); + } else { + LOG(FATAL) << "Unsupported type: " << value.type_index(); + } + break; + } + } + } + void Visit(const char* key, double* value) final { attrs_->Set(key, std::to_string(*value)); } void Visit(const char* key, int64_t* value) final { attrs_->Set(key, std::to_string(*value)); } diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 2290403d3730..70e94b044044 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -22,6 +22,7 @@ * \file node/reflection.cc */ #include +#include #include #include #include @@ -104,8 +105,22 @@ ffi::Any ReflectionVTable::GetAttr(Object* self, const String& field_name) const ret = self->GetTypeKey(); success = true; } else if (!self->IsInstance()) { - VisitAttrs(self, &getter); - success = getter.found_ref_object || ret != nullptr; + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); + success = false; + // use new reflection mechanism + if (type_info->extra_info != nullptr) { + ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + if (field_name.compare(field_info->name) == 0) { + ffi::reflection::FieldGetter field_getter(field_info); + ret = field_getter(self); + success = true; + } + }); + } else { + // legacy reflection mechanism, will be phased out in the future + VisitAttrs(self, &getter); + success = getter.found_ref_object || ret != nullptr; + } } else { // specially handle dict attr DictAttrsNode* dnode = static_cast(self); @@ -149,7 +164,16 @@ std::vector ReflectionVTable::ListAttrNames(Object* self) const { dir.names = &names; if (!self->IsInstance()) { - VisitAttrs(self, &dir); + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index()); + if (type_info->extra_info != nullptr) { + // use new reflection mechanism + ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + names.push_back(std::string(field_info->name.data, field_info->name.size)); + }); + } else { + // legacy reflection mechanism, will be phased out in the future + VisitAttrs(self, &dir); + } } else { // specially handle dict attr DictAttrsNode* dnode = static_cast(self); @@ -288,8 +312,20 @@ void NodeListAttrNames(ffi::PackedArgs args, ffi::Any* ret) { // args format: // key1, value1, ..., key_n, value_n void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { + // dispatch between new reflection and old reflection auto type_key = args[0].cast(); - *rv = ReflectionVTable::Global()->CreateObject(type_key, args.Slice(1)); + int32_t type_index; + TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); + if (type_info->extra_info != nullptr) { + auto fcreate_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs"); + fcreate_object.CallPacked(args, rv); + return; + } else { + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + *rv = ReflectionVTable::Global()->CreateObject(type_key, args.Slice(1)); + } } TVM_FFI_REGISTER_GLOBAL("node.NodeGetAttr").set_body_packed(NodeGetAttr); @@ -332,13 +368,31 @@ class GetAttrKeyByAddressVisitor : public AttrVisitor { } // anonymous namespace Optional GetAttrKeyByAddress(const Object* object, const void* attr_address) { - GetAttrKeyByAddressVisitor visitor(attr_address); - ReflectionVTable::Global()->VisitAttrs(const_cast(object), &visitor); - const char* key = visitor.GetKey(); - if (key == nullptr) { - return std::nullopt; + // NOTE: reflection dispatch for both new and legacy reflection mechanism + const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(object->type_index()); + if (tinfo->extra_info != nullptr) { + Optional result; + // visit fields with the new reflection + ffi::reflection::ForEachFieldInfoWithEarlyStop(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(object); + const void* field_addr = reinterpret_cast(object) + field_info->offset; + if (field_addr == attr_address) { + result = String(field_info->name); + return true; + } + return false; + }); + return result; } else { - return String(key); + // TODO(tvm-team): remove this path once all objects are transitioned to the new reflection + GetAttrKeyByAddressVisitor visitor(attr_address); + ReflectionVTable::Global()->VisitAttrs(const_cast(object), &visitor); + const char* key = visitor.GetKey(); + if (key == nullptr) { + return std::nullopt; + } else { + return String(key); + } } } diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 986a2d044524..08fc32ad3aae 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -62,7 +63,7 @@ inline std::string Base64Encode(std::string s) { } // indexer to index all the nodes -class NodeIndexer : public AttrVisitor { +class NodeIndexer : private AttrVisitor { public: std::unordered_map node_index_{{Any(nullptr), 0}}; std::vector node_list_{Any(nullptr)}; @@ -133,10 +134,26 @@ class NodeIndexer : public AttrVisitor { Object* n = const_cast(opt_object.value()); // if the node already have repr bytes, no need to visit Attrs. if (!reflection_->GetReprBytes(n, nullptr)) { - reflection_->VisitAttrs(n, this); + this->VisitObjectFields(n); } } } + + void VisitObjectFields(Object* obj) { + const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); + if (tinfo->extra_info != nullptr) { + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + // only make index for ObjectRef + if (field_value.as()) { + this->MakeIndex(field_value); + } + }); + } else { + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + reflection_->VisitAttrs(obj, this); + } + } }; // use map so attributes are ordered. @@ -211,7 +228,7 @@ struct JSONNode { // Helper class to populate the json node // using the existing index. -class JSONAttrGetter : public AttrVisitor { +class JSONAttrGetter : private AttrVisitor { public: const std::unordered_map* node_index_; const std::unordered_map* tensor_index_; @@ -296,7 +313,7 @@ class JSONAttrGetter : public AttrVisitor { // do not need to print additional things once we have repr bytes. if (!reflection_->GetReprBytes(n, &(node_->repr_bytes))) { // recursively index normal object. - reflection_->VisitAttrs(n, this); + this->VisitObjectFields(n); } } else { // handling primitive types @@ -327,9 +344,59 @@ class JSONAttrGetter : public AttrVisitor { } } } + + void VisitObjectFields(Object* obj) { + // dispatch between new reflection and old reflection + const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); + if (tinfo->extra_info != nullptr) { + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + String field_name(field_info->name); + switch (field_value.type_index()) { + case ffi::TypeIndex::kTVMFFINone: { + node_->attrs[field_name] = "null"; + break; + } + case ffi::TypeIndex::kTVMFFIBool: + case ffi::TypeIndex::kTVMFFIInt: { + int64_t value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFIFloat: { + double value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFIDataType: { + DataType value(field_value.cast()); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFINDArray: { + runtime::NDArray value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + default: { + if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + ObjectRef obj = field_value.cast(); + this->Visit(field_info->name.data, &obj); + break; + } else { + LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); + } + } + } + }); + } else { + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + reflection_->VisitAttrs(obj, this); + } + } }; -class FieldDependencyFinder : public AttrVisitor { +class FieldDependencyFinder : private AttrVisitor { public: JSONNode* jnode_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); @@ -385,14 +452,31 @@ class FieldDependencyFinder : public AttrVisitor { jnode_ = jnode; if (auto opt_object = node.as()) { Object* n = const_cast(opt_object.value()); - reflection_->VisitAttrs(n, this); + this->VisitObjectFields(n); + } + } + + void VisitObjectFields(Object* obj) { + // dispatch between new reflection and old reflection + const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); + if (tinfo->extra_info != nullptr) { + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + if (auto opt_object = field_value.as()) { + ObjectRef obj = *std::move(opt_object); + this->Visit(field_info->name.data, &obj); + } + }); + } else { + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + reflection_->VisitAttrs(obj, this); } } }; // Helper class to set the attributes of a node // from given json node. -class JSONAttrSetter : public AttrVisitor { +class JSONAttrSetter : private AttrVisitor { public: const std::vector* node_list_; const std::vector* tensor_list_; @@ -543,7 +627,62 @@ class JSONAttrSetter : public AttrVisitor { if (jnode->repr_bytes.length() > 0 || reflection_->GetReprBytes(n, nullptr)) { return; } - reflection_->VisitAttrs(n, this); + this->SetObjectFields(n); + } + } + + void SetObjectFields(Object* obj) { + // dispatch between new reflection and old reflection + const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); + if (tinfo->extra_info != nullptr) { + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + this->SetObjectField(obj, field_info); + }); + } else { + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + reflection_->VisitAttrs(obj, this); + } + } + + void SetObjectField(Object* obj, const TVMFFIFieldInfo* field_info) { + ffi::reflection::FieldSetter setter(field_info); + switch (field_info->field_static_type_index) { + case ffi::TypeIndex::kTVMFFIBool: + case ffi::TypeIndex::kTVMFFIInt: { + Optional value; + this->Visit(field_info->name.data, &value); + setter(obj, value); + break; + } + case ffi::TypeIndex::kTVMFFIFloat: { + Optional value; + this->Visit(field_info->name.data, &value); + setter(obj, value); + break; + } + case ffi::TypeIndex::kTVMFFIDataType: { + DataType value; + this->Visit(field_info->name.data, &value); + setter(obj, value); + break; + } + case ffi::TypeIndex::kTVMFFINDArray: { + runtime::NDArray value; + this->Visit(field_info->name.data, &value); + setter(obj, value); + break; + } + default: { + if (field_info->field_static_type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + ObjectRef value; + this->Visit(field_info->name.data, &value); + setter(obj, value); + break; + } else { + LOG(FATAL) << "Unsupported type: " << field_info->field_static_type_index; + } + } } } }; diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 6b19fb5355bb..d1163269a8b3 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -193,9 +193,13 @@ bool SEqualReducer::AnyEqual(const ffi::Any& lhs, const ffi::Any& rhs, if (paths) { return operator()(lhs.cast(), rhs.cast(), paths.value()); } else { - return operator()(lhs.cast(), rhs.cast()); + ObjectRef lhs_obj = lhs.cast(); + ObjectRef rhs_obj = rhs.cast(); + bool result = operator()(lhs_obj, rhs_obj); + return result; } } + if (ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(lhs)->v_uint64 == ffi::details::AnyUnsafe::TVMFFIAnyPtrFromAny(rhs)->v_uint64) { return true; diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index f7df28bf716f..9aa693f58b56 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -56,7 +57,7 @@ using JSONGraphObjectPtr = std::shared_ptr; * \brief Helper class to extract all attributes of a certain op and save them * into text format. */ -class OpAttrExtractor : public AttrVisitor { +class OpAttrExtractor : private AttrVisitor { public: explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {} @@ -150,11 +151,58 @@ class OpAttrExtractor : public AttrVisitor { void Extract(Object* node) { if (node) { - reflection_->VisitAttrs(node, this); + this->VisitObjectFields(node); } } private: + void VisitObjectFields(Object* obj) { + const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); + if (tinfo->extra_info != nullptr) { + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + switch (field_value.type_index()) { + case ffi::TypeIndex::kTVMFFINone: { + SetNodeAttr(field_info->name.data, {""}); + break; + } + case ffi::TypeIndex::kTVMFFIBool: + case ffi::TypeIndex::kTVMFFIInt: { + int64_t value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFIFloat: { + double value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFIDataType: { + DataType value(field_value.cast()); + this->Visit(field_info->name.data, &value); + break; + } + case ffi::TypeIndex::kTVMFFINDArray: { + runtime::NDArray value = field_value.cast(); + this->Visit(field_info->name.data, &value); + break; + } + default: { + if (field_value.type_index() >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + ObjectRef obj = field_value.cast(); + this->Visit(field_info->name.data, &obj); + break; + } + LOG(FATAL) << "Unsupported type: " << field_value.GetTypeKey(); + } + } + }); + } else { + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + reflection_->VisitAttrs(obj, this); + } + } + JSONGraphObjectPtr node_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); }; diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index 2f6314221ecc..c73cf672abd1 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -27,6 +27,12 @@ namespace relax { /* relax.ccl.allreduce */ TVM_REGISTER_NODE_TYPE(AllReduceAttrs); +TVM_FFI_STATIC_INIT_BLOCK({ + AllReduceAttrs::RegisterReflection(); + AllGatherAttrs::RegisterReflection(); + ScatterCollectiveAttrs::RegisterReflection(); +}); + Expr allreduce(Expr x, String op_type, bool in_group) { ObjectPtr attrs = make_object(); attrs->op_type = std::move(op_type); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 8c72eb4ef318..d906c1baf54d 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -17,6 +17,8 @@ * under the License. */ #include +#include +#include #include #include @@ -104,9 +106,22 @@ void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, ffi::TypedFunction is_var) { - class Visitor : public AttrVisitor { + class Visitor : private AttrVisitor { public: - inline void operator()(ObjectRef obj) { Visit("", &obj); } + void operator()(ObjectRef obj) { + const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); + if (tinfo->extra_info != nullptr) { + // visit fields with the new reflection + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + this->RecursiveVisitAny(&field_value); + }); + } else { + // NOTE: legacy VisitAttrs mechanism + // TODO(tvm-team): remove this once all objects are transitioned to the new reflection + this->Visit("", &obj); + } + } private: void RecursiveVisitAny(ffi::Any* value) { diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 82c2083044ec..3dd6cab0526e 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -25,12 +26,30 @@ namespace tvm { namespace script { namespace printer { -class AttrPrinter : public tvm::AttrVisitor { +class AttrPrinter : private AttrVisitor { public: explicit AttrPrinter(ObjectPath p, const IRDocsifier& d, Array* keys, Array* values) : p(std::move(p)), d(d), keys(keys), values(values) {} + void operator()(const tvm::Attrs& attrs) { + // NOTE: reflection dispatch for both new and legacy reflection mechanism + const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index()); + if (attrs_tinfo->extra_info != nullptr) { + LOG(INFO) << "Using new reflection to print attrs" << String(attrs_tinfo->type_key); + // new printing mechanism using the new reflection + ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) { + String field_name = String(field_info->name); + Any field_value = ffi::reflection::FieldGetter(field_info)(attrs); + keys->push_back(field_name); + values->push_back(d->AsDoc(field_value, p->Attr(field_name))); + }); + } else { + const_cast(attrs.get())->VisitAttrs(this); + } + } + + private: void Visit(const char* key, double* value) final { keys->push_back(key); values->push_back(LiteralDoc::Float(*value, p->Attr(key))); @@ -235,8 +254,7 @@ Optional PrintHintOnDevice(const relax::Call& n, const ObjectPath& n_p, Array kwargs_values; ICHECK(n->attrs.defined()); if (n->attrs.as()) { - AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values); - const_cast(n->attrs.get())->VisitAttrs(&printer); + AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values)(n->attrs); args.push_back(Relax(d, "device")->Call({}, kwargs_keys, kwargs_values)); } return Relax(d, "hint_on_device")->Call(args); @@ -355,8 +373,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) d->AsDoc(kv.second, n_p->Attr("attrs")->Attr(kv.first))); } } else { - AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values); - const_cast(n->attrs.get())->VisitAttrs(&printer); + AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values)(n->attrs); } } // Step 4. Print type_args diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index d0d9a35db83e..106e7f985b65 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -34,23 +35,29 @@ namespace tvm { // Attrs used to python API -struct TestAttrs : public AttrsNode { +struct TestAttrs : public AttrsNodeReflAdapter { int axis; String name; Array padding; TypedEnvFunc func; - TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") { - TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe( - "axis field"); - TVM_ATTR_FIELD(name).describe("name"); - TVM_ATTR_FIELD(padding).describe("padding of input").set_default(Array({0, 0})); - TVM_ATTR_FIELD(func) - .describe("some random env function") - .set_default(TypedEnvFunc(nullptr)); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("axis", &TestAttrs::axis, "axis field", refl::DefaultValue(10)) + .def_ro("name", &TestAttrs::name, "name") + .def_ro("padding", &TestAttrs::padding, "padding of input", + refl::DefaultValue(Array({0, 0}))) + .def_ro("func", &TestAttrs::func, "some random env function", + refl::DefaultValue(TypedEnvFunc(nullptr))); } + + static constexpr const char* _type_key = "attrs.TestAttrs"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestAttrs, BaseAttrsNode); }; +TVM_FFI_STATIC_INIT_BLOCK({ TestAttrs::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(TestAttrs); TVM_FFI_REGISTER_GLOBAL("testing.GetShapeSize").set_body_typed([](ffi::Shape shape) { diff --git a/tests/python/ir/test_ir_attrs.py b/tests/python/ir/test_ir_attrs.py index ce8ac3e4ba38..d61538ac2512 100644 --- a/tests/python/ir/test_ir_attrs.py +++ b/tests/python/ir/test_ir_attrs.py @@ -20,18 +20,22 @@ def test_make_attrs(): - with pytest.raises(AttributeError): + with pytest.raises(TypeError): x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx") - with pytest.raises(AttributeError): - x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx") - x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) assert x.name == "xx" assert x.padding[0].value == 3 assert x.padding[1].value == 4 assert x.axis == 10 + x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) + y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 5)) + z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 5)) + assert not tvm.ir.structural_equal(x, y) + assert tvm.ir.structural_equal(x, x) + assert tvm.ir.structural_equal(y, z) + def test_dict_attrs(): dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0, 0))