diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 7cf7543f482d..a43f2e83b500 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -303,12 +303,54 @@ typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result); */ typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value); +/*! + * \brief bitmask of the field. + */ +#ifdef __cplusplus +enum TVMFFIFieldFlagBitMask : int32_t { +#else +typedef enum { +#endif + /*! \brief The field is writable. */ + TVMFFIFieldFlagBitMaskWritable = 1 << 0, + /*! \brief The field has default value. */ + TVMFFIFieldFlagBitMaskHasDefault = 1 << 1, + /*! \brief The field is a static method. */ + TVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2, +#ifdef __cplusplus +}; +#else +} TVMFFIFieldFlagBitMask; +#endif + /*! * \brief Information support for optional object reflection. */ typedef struct { /*! \brief The name of the field. */ TVMFFIByteArray name; + /*! \brief The docstring about the field. */ + TVMFFIByteArray doc; + /*! + * \brief bitmask flags of the field. + */ + int64_t flags; + /*! + * \brief Byte offset of the field. + */ + int64_t byte_offset; + /*! \brief The getter to access the field. */ + TVMFFIFieldGetter getter; + /*! + * \brief The setter to access the field. + * \note The setter is set even if the field is readonly for serialization. + */ + TVMFFIFieldSetter setter; + /*! + * \brief The default value of the field, this field hold AnyView, + * valid when flags set TVMFFIFieldFlagBitMaskHasDefault + */ + TVMFFIAny default_value; /*! * \brief Records the static type kind of the field. * @@ -317,28 +359,18 @@ typedef struct { * - TVMFFITypeIndex::kTVMFFIObject for general objects * - The value is nullable when kTVMFFIObject is chosen * - static object type kinds such as Map, Dict, String - * - POD type index + * - POD type index, note it does not give information about storage size of the field. * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info * about the field. * - * \note This information is helpful in designing serializer - * of the field. As it helps to narrow down the type of the - * object. It also helps to provide opportunities to enable - * short-cut access to the field. + * When the value is a type index of Object type, the field is storaged as an ObjectRef. + * + * \note This information maybe helpful in designing serializer. + * As it helps to narrow down the field type so we don't have to + * print type_key for cases like POD types. + * It also helps to provide opportunities to enable short-cut getter to ObjectRef fields. */ int32_t field_static_type_index; - /*! - * \brief Mark whether field is readonly. - */ - int32_t readonly; - /*! - * \brief Byte offset of the field. - */ - int64_t byte_offset; - /*! \brief The getter to access the field. */ - TVMFFIFieldGetter getter; - /*! \brief The setter to access the field. */ - TVMFFIFieldSetter setter; } TVMFFIFieldInfo; /*! @@ -347,11 +379,15 @@ typedef struct { typedef struct { /*! \brief The name of the field. */ TVMFFIByteArray name; + /*! \brief The docstring about the method. */ + TVMFFIByteArray doc; + /*! \brief bitmask flags of the method. */ + int64_t flags; /*! - * \brief The method wrapped as Function - * \note The first argument to the method is always the self. + * \brief The method wrapped as ffi::Function, stored as AnyView. + * \note The first argument to the method is always the self for instance methods. */ - TVMFFIObjectHandle method; + TVMFFIAny method; } TVMFFIMethodInfo; /*! @@ -379,6 +415,7 @@ typedef struct { int32_t num_fields; /*! \brief number of reflection acccesible methods. */ int32_t num_methods; + /*! \brief The reflection field information. */ TVMFFIFieldInfo* fields; /*! \brief The reflection method. */ @@ -522,12 +559,29 @@ TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol) // Section: Type reflection support APIs //------------------------------------------------------------ /*! - * \brief Register type field information for rutnime reflection. + * \brief Register type field information for runtime reflection. * \param type_index The type index * \param info The field info to be registered. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info); +TVM_FFI_DLL int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info); + +/*! + * \brief Register type method information for runtime reflection. + * \param type_index The type index + * \param info The method info to be registered. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info); + +/*! + * \brief Get dynamic type info by type index. + * + * \param type_index The type index + * \param result The output type information + * \return The type info + */ +TVM_FFI_DLL const TVMFFITypeInfo* TVMFFITypeGetMethod(int32_t type_index); //------------------------------------------------------------ // Section: DLPack support APIs @@ -638,7 +692,7 @@ TVM_FFI_DLL const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lin * * \return 0 if success, -1 if error occured */ -TVM_FFI_DLL int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key, +TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, int32_t static_type_index, int32_t type_depth, int32_t num_child_slots, int32_t child_slots_can_overflow, diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 128c67830e84..eb3f390bcdec 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -836,14 +836,14 @@ class Function::Registry { if constexpr (std::is_base_of_v) { auto fwrap = [f](T target, Args... params) -> R { // call method pointer - return (target.*f)(params...); + return (target.*f)(std::forward(params)...); }; return Register(ffi::Function::FromTyped(fwrap, name_)); } if constexpr (std::is_base_of_v) { auto fwrap = [f](const T* target, Args... params) -> R { // call method pointer - return (const_cast(target)->*f)(params...); + return (const_cast(target)->*f)(std::forward(params)...); }; return Register(ffi::Function::FromTyped(fwrap, name_)); } @@ -857,14 +857,14 @@ class Function::Registry { if constexpr (std::is_base_of_v) { auto fwrap = [f](const T target, Args... params) -> R { // call method pointer - return (target.*f)(params...); + return (target.*f)(std::forward(params)...); }; return Register(ffi::Function::FromTyped(fwrap, name_)); } if constexpr (std::is_base_of_v) { auto fwrap = [f](const T* target, Args... params) -> R { // call method pointer - return (target->*f)(params...); + return (target->*f)(std::forward(params)...); }; return Register(ffi::Function::FromTyped(fwrap, name_)); } diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index e86689ebe23a..72e6f0a1f8ec 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -103,6 +103,9 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index); * This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO * It is still OK to sub-class a terminal object type T and construct it using make_object. * But IsInstance check will only show that the object type is T(instead of the sub-class). + * - _type_mutable: + * Whether we would like to expose cast to non-constant pointer + * ObjectType* from Any/AnyView. By default, we set to false so it is not exposed. * * The following two fields are necessary for base classes that can be sub-classed. * @@ -191,6 +194,7 @@ class Object { // Default object type properties for sub-classes static constexpr bool _type_final = false; + static constexpr bool _type_mutable = false; static constexpr uint32_t _type_child_slots = 0; static constexpr bool _type_child_slots_can_overflow = true; // NOTE: static type index field of the class @@ -546,7 +550,7 @@ struct ObjectPtrEqual { "Need to set _type_child_slots when parent specifies it."); \ TVMFFIByteArray type_key{TypeName::_type_key, \ std::char_traits::length(TypeName::_type_key)}; \ - static int32_t tindex = TVMFFIGetOrAllocTypeIndex( \ + static int32_t tindex = TVMFFITypeGetOrAllocIndex( \ &type_key, TypeName::_type_index, TypeName::_type_depth, TypeName::_type_child_slots, \ TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ return tindex; \ @@ -576,7 +580,7 @@ struct ObjectPtrEqual { "Need to set _type_child_slots when parent specifies it."); \ TVMFFIByteArray type_key{TypeName::_type_key, \ std::char_traits::length(TypeName::_type_key)}; \ - static int32_t tindex = TVMFFIGetOrAllocTypeIndex( \ + static int32_t tindex = TVMFFITypeGetOrAllocIndex( \ &type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ return tindex; \ diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h index 766b9b809958..a5ab6f4fe87d 100644 --- a/ffi/include/tvm/ffi/reflection/reflection.h +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -29,10 +29,31 @@ #include #include +#include namespace tvm { namespace ffi { -namespace details { +/*! \brief Reflection namespace */ +namespace reflection { + +/*! \brief Trait that can be used to set field info */ +struct FieldInfoTrait {}; + +/*! + * \brief Trait that can be used to set field default value + */ +class DefaultValue : public FieldInfoTrait { + public: + explicit DefaultValue(Any value) : value_(value) {} + + void Apply(TVMFFIFieldInfo* info) const { + info->default_value = AnyView(value_).CopyToTVMFFIAny(); + info->flags |= TVMFFIFieldFlagBitMaskHasDefault; + } + + private: + Any value_; +}; /*! * \brief Get the byte offset of a class member field. @@ -50,37 +71,108 @@ inline int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); } -struct ReflectionDefFinish {}; - class ReflectionDef { public: - explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {} + explicit ReflectionDef(int32_t type_index, const char* type_key) + : type_index_(type_index), type_key_(type_key) {} + + /*! + * \brief Define a readonly field. + * + * \tparam Class The class type. + * \tparam T The field type. + * \tparam Extra The extra arguments. + * + * \param name The name of the field. + * \param field_ptr The pointer to the field. + * \param extra The extra arguments that can be docstring or default value. + * + * \return The reflection definition. + */ + template + ReflectionDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) { + RegisterField(name, field_ptr, false, std::forward(extra)...); + return *this; + } - template - ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) { - RegisterField(name, field_ptr, true); + /*! + * \brief Define a read-write field. + * + * \tparam Class The class type. + * \tparam T The field type. + * \tparam Extra The extra arguments. + * + * \param name The name of the field. + * \param field_ptr The pointer to the field. + * \param extra The extra arguments that can be docstring or default value. + * + * \return The reflection definition. + */ + template + ReflectionDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) { + RegisterField(name, field_ptr, true, std::forward(extra)...); return *this; } - template - ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) { - RegisterField(name, field_ptr, false); + /*! + * \brief Define a method. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + ReflectionDef& def(const char* name, Func&& func, Extra&&... extra) { + RegisterMethod(name, false, std::forward(func), std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a static method. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + ReflectionDef& def_static(const char* name, Func&& func, Extra&&... extra) { + RegisterMethod(name, true, std::forward(func), std::forward(extra)...); return *this; } private: - template - void RegisterField(const char* name, T Class::*field_ptr, bool readonly) { + template + void RegisterField(const char* name, T Class::*field_ptr, bool writable, + ExtraArgs&&... extra_args) { TVMFFIFieldInfo info; info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; info.field_static_type_index = TypeToFieldStaticTypeIndex::value; // store byte offset and setter, getter // so the same setter can be reused for all the same type info.byte_offset = GetFieldByteOffsetToObject(field_ptr); - info.readonly = readonly; + info.flags = 0; + if (writable) { + info.flags |= TVMFFIFieldFlagBitMaskWritable; + } info.getter = FieldGetter; info.setter = FieldSetter; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIRegisterTypeField(type_index_, &info)); + // initialize default value to nullptr + info.default_value = AnyView(nullptr).CopyToTVMFFIAny(); + info.doc = TVMFFIByteArray{nullptr, 0}; + // apply field info traits + ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); + // call register + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); } template @@ -97,14 +189,70 @@ class ReflectionDef { TVM_FFI_SAFE_CALL_END(); } + template + static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { + if constexpr (std::is_base_of_v>) { + value.Apply(info); + } + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + // register a method + template + void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { + TVMFFIMethodInfo info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.doc = TVMFFIByteArray{nullptr, 0}; + info.flags = 0; + if (is_static) { + info.flags |= TVMFFIFieldFlagBitMaskIsStaticMethod; + } + // obtain the method function + Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); + info.method = AnyView(method).CopyToTVMFFIAny(); + // apply method info traits + ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); + } + + template + static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) { + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + template + static Function GetMethod(std::string name, R (Class::*func)(Args...)) { + auto fwrap = [func](const Class* target, Args... params) -> R { + return (const_cast(target)->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } + + template + static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { + auto fwrap = [func](const Class* target, Args... params) -> R { + return (target->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } + + template + static Function GetMethod(std::string name, Func&& func) { + return ffi::Function::FromTyped(std::forward(func), name); + } + int32_t type_index_; + const char* type_key_; }; /*! * \brief helper function to get reflection field info by type key and field name */ -inline const TVMFFIFieldInfo* GetReflectionFieldInfo(std::string_view type_key, - const char* field_name) { +inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char* field_name) { int32_t type_index; TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); @@ -115,14 +263,18 @@ inline const TVMFFIFieldInfo* GetReflectionFieldInfo(std::string_view type_key, } } TVM_FFI_THROW(RuntimeError) << "Cannot find field " << field_name << " in " << type_key; + TVM_FFI_UNREACHABLE(); } /*! * \brief helper wrapper class to obtain a getter. */ -class ReflectionFieldGetter { +class FieldGetter { public: - explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} + explicit FieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} + + explicit FieldGetter(std::string_view type_key, const char* field_name) + : FieldGetter(GetFieldInfo(type_key, field_name)) {} Any operator()(const Object* obj_ptr) const { Any result; @@ -140,16 +292,81 @@ class ReflectionFieldGetter { const TVMFFIFieldInfo* field_info_; }; -#define TVM_FFI_REFLECTION_REG_VAR_DEF \ - static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::details::ReflectionDef& __TVMFFIReflectionReg +/*! + * \brief helper wrapper class to obtain a setter. + */ +class FieldSetter { + public: + explicit FieldSetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} + + explicit FieldSetter(std::string_view type_key, const char* field_name) + : FieldSetter(GetFieldInfo(type_key, field_name)) {} + + void operator()(const Object* obj_ptr, AnyView value) const { + const void* addr = reinterpret_cast(obj_ptr) + field_info_->byte_offset; + TVM_FFI_CHECK_SAFE_CALL( + field_info_->setter(const_cast(addr), reinterpret_cast(&value))); + } + + void operator()(const ObjectPtr& obj_ptr, AnyView value) const { + operator()(obj_ptr.get(), value); + } + + void operator()(const ObjectRef& obj, AnyView value) const { operator()(obj.get(), value); } + + private: + const TVMFFIFieldInfo* field_info_; +}; + +/*! + * \brief helper function to get reflection method info by type key and method name + * + * \param type_key The type key. + * \param method_name The name of the method. + * \return The method info. + */ +inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const char* method_name) { + int32_t type_index; + TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TypeInfo* info = TVMFFIGetTypeInfo(type_index); + for (int32_t i = 0; i < info->num_methods; ++i) { + if (std::strncmp(info->methods[i].name.data, method_name, info->methods[i].name.size) == 0) { + return &(info->methods[i]); + } + } + TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in " << type_key; + TVM_FFI_UNREACHABLE(); +} + +/*! + * \brief helper function to get reflection method function by method info + * + * \param type_key The type key. + * \param method_name The name of the method. + * \return The method function. + */ +inline Function GetMethod(std::string_view type_key, const char* method_name) { + const TVMFFIMethodInfo* info = GetMethodInfo(type_key, method_name); + return AnyView::CopyFromTVMFFIAny(info->method).cast(); +} + +#define TVM_FFI_REFLECTION_REG_VAR_DEF \ + static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::reflection::ReflectionDef& \ + __TVMFFIReflectionReg /*! * helper macro to define a reflection definition for an object */ -#define TVM_FFI_REFLECTION_DEF(TypeName) \ - TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::ffi::details::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex()) -} // namespace details +#define TVM_FFI_REFLECTION_DEF(TypeName) \ + TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ + ::tvm::ffi::reflection::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex(), \ + TypeName::_type_key) + +} // namespace reflection + +/*! \brief Shortcut to the reflection namespace */ +namespace refl = reflection; } // namespace ffi } // namespace tvm #endif // TVM_FFI_REFLECTION_REFLECTION_H_ diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index c3eceff90590..19df2e8e3dcf 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -401,7 +401,6 @@ TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) { template struct TypeTraits : public TypeTraitsBase { // NOTE: only enable implicit conversion into AnyView - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr; static constexpr bool storage_enabled = false; static TVM_FFI_INLINE void CopyToAnyView(const char src[N], TVMFFIAny* result) { @@ -417,7 +416,6 @@ struct TypeTraits : public TypeTraitsBase { template <> struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr; static constexpr bool storage_enabled = false; static TVM_FFI_INLINE void CopyToAnyView(const char* src, TVMFFIAny* result) { diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index 02c9a90edcfd..5c291b553570 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -84,6 +84,7 @@ inline constexpr bool use_default_type_traits_v = true; struct TypeTraitsBase { static constexpr bool convert_enabled = true; static constexpr bool storage_enabled = true; + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; // get mismatched type when result mismatches the trait. // this function is called after TryCastFromAnyView fails // to get more detailed type information in runtime @@ -588,17 +589,18 @@ struct ObjectRefWithFallbackTraitsBase : public ObjectRefTypeTraitsBase -struct TypeTraits>> +struct TypeTraits>> : public TypeTraitsBase { - static TVM_FFI_INLINE void CopyToAnyView(const TObject* src, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(TObject* src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); result->type_index = obj_ptr->type_index; TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); result->v_obj = obj_ptr; } - static TVM_FFI_INLINE void MoveToAny(const TObject* src, TVMFFIAny* result) { + static TVM_FFI_INLINE void MoveToAny(TObject* src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); result->type_index = obj_ptr->type_index; TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); @@ -612,11 +614,17 @@ struct TypeTraits(src->type_index); } - static TVM_FFI_INLINE const TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + if constexpr (!std::is_const_v) { + static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); + } return details::ObjectUnsafe::RawObjectPtrFromUnowned(src->v_obj); } - static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { + if constexpr (!std::is_const_v) { + static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); + } if (CheckAnyStrict(src)) return CopyFromAnyViewAfterCheck(src); return std::nullopt; } diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 63ec68790e57..6ce149a7c9f3 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -54,6 +54,8 @@ class TypeTable { std::vector type_acenstors_data; /*! \brief type fields informaton */ std::vector type_fields_data; + /*! \brief type methods informaton */ + std::vector type_methods_data; // NOTE: the indices in [index, index + num_reserved_slots) are // reserved for the child-class of this type. /*! \brief Total number of slots reserved for the type and its children. */ @@ -186,12 +188,29 @@ class TypeTable { Entry* entry = GetTypeEntry(type_index); TVMFFIFieldInfo field_data = *info; field_data.name = this->CopyString(info->name); + field_data.doc = this->CopyString(info->doc); + if (info->flags & TVMFFIFieldFlagBitMaskHasDefault) { + field_data.default_value = + this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value)).CopyToTVMFFIAny(); + } else { + field_data.default_value = AnyView(nullptr).CopyToTVMFFIAny(); + } entry->type_fields_data.push_back(field_data); // refresh ptr as the data can change entry->fields = entry->type_fields_data.data(); entry->num_fields = static_cast(entry->type_fields_data.size()); } + void RegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info) { + Entry* entry = GetTypeEntry(type_index); + TVMFFIMethodInfo method_data = *info; + method_data.name = this->CopyString(info->name); + method_data.doc = this->CopyString(info->doc); + method_data.method = this->CopyAny(AnyView::CopyFromTVMFFIAny(info->method)).CopyToTVMFFIAny(); + entry->type_methods_data.push_back(method_data); + entry->methods = entry->type_methods_data.data(); + entry->num_methods = static_cast(entry->type_methods_data.size()); + } void Dump(int min_children_count) { std::vector num_children(type_table_.size(), 0); // expected child slots compute the expected slots @@ -262,16 +281,25 @@ class TypeTable { } TVMFFIByteArray CopyString(TVMFFIByteArray str) { - std::unique_ptr val = std::make_unique(str.data, str.size); - TVMFFIByteArray c_val{val->data(), val->length()}; - string_pool_.emplace_back(std::move(val)); + if (str.size == 0) { + return TVMFFIByteArray{nullptr, 0}; + } + String val = String(str.data, str.size); + TVMFFIByteArray c_val{val.data(), val.length()}; + any_pool_.emplace_back(std::move(val)); return c_val; } + AnyView CopyAny(Any val) { + AnyView view = AnyView(val); + any_pool_.emplace_back(std::move(val)); + return view; + } + int32_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin}; std::vector> type_table_; std::unordered_map type_key2index_; - std::vector> string_pool_; + std::vector any_pool_; }; } // namespace ffi } // namespace tvm @@ -288,13 +316,19 @@ int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) { TVM_FFI_SAFE_CALL_END(); } -int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) { +int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info) { TVM_FFI_SAFE_CALL_BEGIN(); tvm::ffi::TypeTable::Global()->RegisterTypeField(type_index, info); TVM_FFI_SAFE_CALL_END(); } -int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key, int32_t static_type_index, +int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::TypeTable::Global()->RegisterTypeMethod(type_index, info); + TVM_FFI_SAFE_CALL_END(); +} + +int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, int32_t static_type_index, int32_t type_depth, int32_t num_child_slots, int32_t child_slots_can_overflow, int32_t parent_type_index) { TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); @@ -302,7 +336,7 @@ int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key, int32_t stati return tvm::ffi::TypeTable::Global()->GetOrAllocTypeIndex( s_type_key, static_type_index, type_depth, num_child_slots, child_slots_can_overflow, parent_type_index); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetOrAllocTypeIndex); + TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFITypeGetOrAllocIndex); } const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) { diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc index 3ad81cd11803..eea18c7c644f 100644 --- a/ffi/tests/cpp/test_any.cc +++ b/ffi/tests/cpp/test_any.cc @@ -232,6 +232,9 @@ TEST(Any, Object) { EXPECT_EQ(v1.use_count(), 3); EXPECT_TRUE(any2.as().has_value()); + any2 = const_cast(v1_ptr); + EXPECT_TRUE(any2.as().has_value()); + // convert to raw opaque ptr void* raw_v1_ptr = const_cast(v1_ptr); any2 = raw_v1_ptr; diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc index cb42f32c6cfe..321af7ae16ac 100644 --- a/ffi/tests/cpp/test_array.cc +++ b/ffi/tests/cpp/test_array.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include "./testing_object.h" diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc index fec167d257e3..76e8d35a99cd 100644 --- a/ffi/tests/cpp/test_reflection.cc +++ b/ffi/tests/cpp/test_reflection.cc @@ -20,33 +20,108 @@ #include #include #include +#include #include "./testing_object.h" namespace { - using namespace tvm::ffi; using namespace tvm::ffi::testing; +TVM_FFI_REFLECTION_DEF(TFloatObj) + .def_rw("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0)) + .def("sub", [](const TFloatObj* self, double other) -> double { return self->value - other; }) + .def("add", &TFloatObj::Add, "add method"); + +TVM_FFI_REFLECTION_DEF(TIntObj) + .def_ro("value", &TIntObj::value) + .def_static("static_add", &TInt::StaticAdd, "static add method"); + +TVM_FFI_REFLECTION_DEF(TPrimExprObj) + .def_ro("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float")) + .def_ro("value", &TPrimExprObj::value, "value field", refl::DefaultValue(0)) + .def("sub", [](TPrimExprObj* self, double other) -> double { + // this is ok because TPrimExprObj is declared asmutable + return self->value - other; + }); + struct A : public Object { int64_t x; int64_t y; }; +TVM_FFI_REFLECTION_DEF(A).def_ro("x", &A::x).def_rw("y", &A::y); + TEST(Reflection, GetFieldByteOffset) { - EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::x), sizeof(TVMFFIObject)); - EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::y), 8 + sizeof(TVMFFIObject)); - EXPECT_EQ(details::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject)); + EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::x), sizeof(TVMFFIObject)); + EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::y), 8 + sizeof(TVMFFIObject)); + EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject)); } TEST(Reflection, FieldGetter) { ObjectRef a = TInt(10); - details::ReflectionFieldGetter getter(details::GetReflectionFieldInfo("test.Int", "value")); + reflection::FieldGetter getter("test.Int", "value"); EXPECT_EQ(getter(a).cast(), 10); ObjectRef b = TFloat(10.0); - details::ReflectionFieldGetter getter_float( - details::GetReflectionFieldInfo("test.Float", "value")); + reflection::FieldGetter getter_float("test.Float", "value"); EXPECT_EQ(getter_float(b).cast(), 10.0); } + +TEST(Reflection, FieldSetter) { + ObjectRef a = TFloat(10.0); + reflection::FieldSetter setter("test.Float", "value"); + setter(a, 20.0); + EXPECT_EQ(a.as()->value, 20.0); +} + +TEST(Reflection, FieldInfo) { + const TVMFFIFieldInfo* info_int = reflection::GetFieldInfo("test.Int", "value"); + EXPECT_FALSE(info_int->flags & TVMFFIFieldFlagBitMaskHasDefault); + EXPECT_FALSE(info_int->flags & TVMFFIFieldFlagBitMaskWritable); + EXPECT_EQ(Bytes(info_int->doc).operator std::string(), ""); + + const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float", "value"); + EXPECT_EQ(info_float->default_value.v_float64, 10.0); + EXPECT_TRUE(info_float->flags & TVMFFIFieldFlagBitMaskHasDefault); + EXPECT_TRUE(info_float->flags & TVMFFIFieldFlagBitMaskWritable); + EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value field"); + + const TVMFFIFieldInfo* info_prim_expr_dtype = reflection::GetFieldInfo("test.PrimExpr", "dtype"); + AnyView default_value = AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value); + EXPECT_EQ(default_value.cast(), "float"); + EXPECT_EQ(default_value.as().value().use_count(), 2); + EXPECT_TRUE(info_prim_expr_dtype->flags & TVMFFIFieldFlagBitMaskHasDefault); + EXPECT_FALSE(info_prim_expr_dtype->flags & TVMFFIFieldFlagBitMaskWritable); + EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype field"); +} + +TEST(Reflection, MethodInfo) { + const TVMFFIMethodInfo* info_int_static_add = reflection::GetMethodInfo("test.Int", "static_add"); + EXPECT_TRUE(info_int_static_add->flags & TVMFFIFieldFlagBitMaskIsStaticMethod); + EXPECT_EQ(Bytes(info_int_static_add->doc).operator std::string(), "static add method"); + + const TVMFFIMethodInfo* info_float_add = reflection::GetMethodInfo("test.Float", "add"); + EXPECT_FALSE(info_float_add->flags & TVMFFIFieldFlagBitMaskIsStaticMethod); + EXPECT_EQ(Bytes(info_float_add->doc).operator std::string(), "add method"); + + const TVMFFIMethodInfo* info_float_sub = reflection::GetMethodInfo("test.Float", "sub"); + EXPECT_FALSE(info_float_sub->flags & TVMFFIFieldFlagBitMaskIsStaticMethod); + EXPECT_EQ(Bytes(info_float_sub->doc).operator std::string(), ""); +} + +TEST(Reflection, CallMethod) { + Function static_int_add = reflection::GetMethod("test.Int", "static_add"); + EXPECT_EQ(static_int_add(TInt(1), TInt(2)).cast()->value, 3); + + Function float_add = reflection::GetMethod("test.Float", "add"); + EXPECT_EQ(float_add(TFloat(1), 2.0).cast(), 3.0); + + Function float_sub = reflection::GetMethod("test.Float", "sub"); + EXPECT_EQ(float_sub(TFloat(1), 2.0).cast(), -1.0); + + Function prim_expr_sub = reflection::GetMethod("test.PrimExpr", "sub"); + EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast(), -1.0); +} + } // namespace diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc index e0f69d820018..5735e86eca4d 100644 --- a/ffi/tests/cpp/test_tuple.cc +++ b/ffi/tests/cpp/test_tuple.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include "./testing_object.h" diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc index 451913c9926d..b140e7db6e4a 100644 --- a/ffi/tests/cpp/test_variant.cc +++ b/ffi/tests/cpp/test_variant.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "./testing_object.h" diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index 69a91efc46d0..8a9184884552 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -22,7 +22,6 @@ #include #include -#include #include namespace tvm { @@ -65,12 +64,12 @@ class TIntObj : public TNumberObj { TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj); }; -TVM_FFI_REFLECTION_DEF(TIntObj).def_readonly("value", &TIntObj::value); - class TInt : public TNumber { public: explicit TInt(int64_t value) { data_ = make_object(value); } + static TInt StaticAdd(TInt lhs, TInt rhs) { return TInt(lhs->value + rhs->value); } + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj); }; @@ -80,12 +79,12 @@ class TFloatObj : public TNumberObj { TFloatObj(double value) : value(value) {} + double Add(double other) const { return value + other; } + static constexpr const char* _type_key = "test.Float"; TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj); }; -TVM_FFI_REFLECTION_DEF(TFloatObj).def_readonly("value", &TFloatObj::value); - class TFloat : public TNumber { public: explicit TFloat(double value) { data_ = make_object(value); } @@ -102,6 +101,7 @@ class TPrimExprObj : public Object { TPrimExprObj(std::string dtype, double value) : dtype(dtype), value(value) {} static constexpr const char* _type_key = "test.PrimExpr"; + static constexpr bool _type_mutable = true; TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TPrimExprObj, Object); };