diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index eb317d2bbd72..02537df79cb4 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -70,7 +70,7 @@ class ObjAllocatorBase { * \param args The arguments. */ template - inline ObjectPtr make_object(Args&&... args) { + ObjectPtr make_object(Args&&... args) { using Handler = typename Derived::template Handler; static_assert(std::is_base_of::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); @@ -89,7 +89,7 @@ class ObjAllocatorBase { * \param args The arguments. */ template - inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { + ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { using Handler = typename Derived::template ArrayHandler; static_assert(std::is_base_of::value, "make_inplace_array can only be used to create Object"); @@ -109,7 +109,9 @@ class SimpleObjAllocator : public ObjAllocatorBase { template class Handler { public: - using StorageType = typename std::aligned_storage::type; + struct alignas(T) StorageType { + char data[sizeof(T)]; + }; template static T* New(SimpleObjAllocator*, Args&&... args) { diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h index bd2f5cb9c76e..6187a74825d6 100644 --- a/ffi/include/tvm/ffi/reflection/reflection.h +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -46,7 +46,7 @@ class DefaultValue : public FieldInfoTrait { public: explicit DefaultValue(Any value) : value_(value) {} - void Apply(TVMFFIFieldInfo* info) const { + TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->default_value = AnyView(value_).CopyToTVMFFIAny(); info->flags |= kTVMFFIFieldFlagBitMaskHasDefault; } @@ -65,16 +65,89 @@ class DefaultValue : public FieldInfoTrait { * \returns The byteoffset */ template -inline int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { +TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { int64_t field_offset_to_class = reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); } +class ReflectionDefBase { + protected: + template + static int FieldGetter(void* field, TVMFFIAny* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); + TVM_FFI_SAFE_CALL_END(); + } + + template + static int FieldSetter(void* field, const TVMFFIAny* value) { + TVM_FFI_SAFE_CALL_BEGIN(); + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); + TVM_FFI_SAFE_CALL_END(); + } + + template + static int ObjectCreatorDefault(TVMFFIObjectHandle* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + ObjectPtr obj = make_object(); + *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + TVM_FFI_SAFE_CALL_END(); + } + + template + static TVM_FFI_INLINE void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { + if constexpr (std::is_base_of_v>) { + value.Apply(info); + } + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + template + static TVM_FFI_INLINE void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) { + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + template + static TVM_FFI_INLINE void ApplyExtraInfoTrait(TVMFFITypeExtraInfo* info, const T& value) { + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + template + static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...)) { + auto fwrap = [func](const Class* target, Args... params) -> R { + return (const_cast(target)->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } + + template + static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...) const) { + auto fwrap = [func](const Class* target, Args... params) -> R { + return (target->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, name); + } + + template + static TVM_FFI_INLINE Function GetMethod(std::string name, Func&& func) { + return ffi::Function::FromTyped(std::forward(func), name); + } +}; + template -class ObjectDef { +class ObjectDef : public ReflectionDefBase { public: - ObjectDef() : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {} + template + explicit ObjectDef(ExtraArgs&&... extra_args) + : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) { + RegisterExtraInfo(std::forward(extra_args)...); + } /*! * \brief Define a readonly field. @@ -90,7 +163,7 @@ class ObjectDef { * \return The reflection definition. */ template - ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) { + TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) { RegisterField(name, field_ptr, false, std::forward(extra)...); return *this; } @@ -109,7 +182,8 @@ class ObjectDef { * \return The reflection definition. */ template - ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) { + TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) { + static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); RegisterField(name, field_ptr, true, std::forward(extra)...); return *this; } @@ -127,7 +201,7 @@ class ObjectDef { * \return The reflection definition. */ template - ObjectDef& def(const char* name, Func&& func, Extra&&... extra) { + TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) { RegisterMethod(name, false, std::forward(func), std::forward(extra)...); return *this; } @@ -145,12 +219,26 @@ class ObjectDef { * \return The reflection definition. */ template - ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { + TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { RegisterMethod(name, true, std::forward(func), std::forward(extra)...); return *this; } private: + template + void RegisterExtraInfo(ExtraArgs&&... extra_args) { + TVMFFITypeExtraInfo info; + info.total_size = sizeof(Class); + info.creator = nullptr; + info.doc = TVMFFIByteArray{nullptr, 0}; + if constexpr (std::is_default_constructible_v) { + info.creator = ObjectCreatorDefault; + } + // apply extra info traits + ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info)); + } + template void RegisterField(const char* name, T Class::*field_ptr, bool writable, ExtraArgs&&... extra_args) { @@ -178,30 +266,6 @@ class ObjectDef { TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); } - template - static int FieldGetter(void* field, TVMFFIAny* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int FieldSetter(void* field, const TVMFFIAny* value) { - TVM_FFI_SAFE_CALL_BEGIN(); - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); - TVM_FFI_SAFE_CALL_END(); - } - - template - static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) { - if constexpr (std::is_base_of_v>) { - value.Apply(info); - } - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - // register a method template void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { @@ -214,41 +278,14 @@ class ObjectDef { info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; } // obtain the method function - Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); + Function method = + GetMethod(std::string(type_key_) + "." + name, std::forward(func)); info.method = AnyView(method).CopyToTVMFFIAny(); // apply method info traits ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); } - template - static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) { - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - static Function GetMethod(std::string name, R (Class::*func)(Args...)) { - auto fwrap = [func](const Class* target, Args... params) -> R { - return (const_cast(target)->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - - template - static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { - auto fwrap = [func](const Class* target, Args... params) -> R { - return (target->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, name); - } - - template - static Function GetMethod(std::string name, Func&& func) { - return ffi::Function::FromTyped(std::forward(func), name); - } - int32_t type_index_; const char* type_key_; }; diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index 19df2e8e3dcf..dee2d89c0854 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -306,6 +306,18 @@ class String : public ObjectRef { return Bytes::memncmp(data(), other, size(), std::strlen(other)); } + /*! + * \brief Compares this to other + * + * \param other The TVMFFIByteArray to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const TVMFFIByteArray& other) const { + return Bytes::memncmp(data(), other.data, size(), other.size); + } + /*! * \brief Returns a pointer to the char array in the string. * diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 793d3e27283a..fa77e2b26401 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -315,6 +315,83 @@ class TypeTable { Map type_key2index_; std::vector any_pool_; }; + +void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) { + String type_key = args[0].cast(); + TVM_FFI_ICHECK(args.size() % 2 == 1); + + int32_t type_index; + TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index); + if (type_info == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Cannot find type `" << type_key << "`"; + } + + if (type_info->extra_info == nullptr || type_info->extra_info->creator == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << type_key << "` does not support reflection creation"; + } + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info->extra_info->creator(&handle)); + ObjectPtr ptr = + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + + std::vector keys; + std::vector keys_found; + + for (int i = 1; i < args.size(); i += 2) { + keys.push_back(args[i].cast()); + } + keys_found.resize(keys.size(), false); + + auto search_field = [&](const TVMFFIByteArray& field_name) { + for (size_t i = 0; i < keys.size(); ++i) { + if (keys_found[i]) continue; + if (keys[i].compare(field_name) == 0) { + return i; + } + } + return keys.size(); + }; + + auto update_fields = [&](const TVMFFITypeInfo* tinfo) { + for (int i = 0; i < tinfo->num_fields; ++i) { + const TVMFFIFieldInfo* field_info = tinfo->fields + i; + size_t arg_index = search_field(field_info->name); + void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; + if (arg_index < keys.size()) { + AnyView field_value = args[arg_index * 2 + 2]; + field_info->setter(field_addr, reinterpret_cast(&field_value)); + keys_found[arg_index] = true; + } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { + field_info->setter(field_addr, &(field_info->default_value)); + } else { + TVM_FFI_THROW(TypeError) << "Required field `" + << String(field_info->name.data, field_info->name.size) + << "` not set in type `" << type_key << "`"; + } + } + }; + + // iterate through acenstors in parent to child order + // skip the first one since it is always the root object + TVM_FFI_ICHECK(type_info->type_acenstors[0] == TypeIndex::kTVMFFIObject); + for (int i = 1; i < type_info->type_depth; ++i) { + update_fields(TVMFFIGetTypeInfo(type_info->type_acenstors[i])); + } + update_fields(type_info); + + for (size_t i = 0; i < keys.size(); ++i) { + if (!keys_found[i]) { + TVM_FFI_THROW(TypeError) << "Type `" << type_key << "` does not have field `" << keys[i] + << "`"; + } + } + *ret = ObjectRef(ptr); +} + +TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs); + } // namespace ffi } // namespace tvm diff --git a/ffi/src/ffi/testing.cc b/ffi/src/ffi/testing.cc index 050ac28c476e..6bc7968eab06 100644 --- a/ffi/src/ffi/testing.cc +++ b/ffi/src/ffi/testing.cc @@ -17,7 +17,10 @@ * under the License. */ // This file is used for testing the FFI API. +#include +#include #include +#include #include #include @@ -26,6 +29,45 @@ namespace tvm { namespace ffi { +class TestObjectBase : public Object { + public: + int64_t v_i64; + double v_f64; + String v_str; + + int64_t AddI64(int64_t other) const { return v_i64 + other; } + + // declare as one slot, with float as overflow + static constexpr bool _type_mutable = true; + static constexpr uint32_t _type_child_slots = 1; + static constexpr const char* _type_key = "testing.TestObjectBase"; + TVM_FFI_DECLARE_BASE_OBJECT_INFO(TestObjectBase, Object); +}; + +class TestObjectDerived : public TestObjectBase { + public: + Map v_map; + Array v_array; + + // declare as one slot, with float as overflow + static constexpr const char* _type_key = "testing.TestObjectDerived"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjectDerived, TestObjectBase); +}; + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + + refl::ObjectDef() + .def_rw("v_i64", &TestObjectBase::v_i64, refl::DefaultValue(10), "i64 field") + .def_ro("v_f64", &TestObjectBase::v_f64, refl::DefaultValue(10.0)) + .def_rw("v_str", &TestObjectBase::v_str, refl::DefaultValue("hello")) + .def("add_i64", &TestObjectBase::AddI64, "add_i64 method"); + + refl::ObjectDef() + .def_ro("v_map", &TestObjectDerived::v_map) + .def_ro("v_array", &TestObjectDerived::v_array); +}); + void TestRaiseError(String kind, String msg) { throw ffi::Error(kind, msg, TVM_FFI_TRACEBACK_HERE); } diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc index 64b3a6f590eb..450cb9dbcbf7 100644 --- a/ffi/tests/cpp/test_reflection.cc +++ b/ffi/tests/cpp/test_reflection.cc @@ -32,13 +32,15 @@ using namespace tvm::ffi::testing; struct A : public Object { int64_t x; int64_t y; + + static constexpr bool _type_mutable = true; }; TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_rw("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0)) + .def_ro("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0)) .def("sub", [](const TFloatObj* self, double other) -> double { return self->value - other; }) .def("add", &TFloatObj::Add, "add method"); @@ -47,7 +49,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def_static("static_add", &TInt::StaticAdd, "static add method"); refl::ObjectDef() - .def_ro("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float")) + .def_rw("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float")) .def_ro("value", &TPrimExprObj::value, "value field", refl::DefaultValue(0)) .def("sub", [](TPrimExprObj* self, double other) -> double { // this is ok because TPrimExprObj is declared asmutable @@ -89,7 +91,7 @@ TEST(Reflection, FieldInfo) { const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float", "value"); EXPECT_EQ(info_float->default_value.v_float64, 10.0); EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_TRUE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable); + EXPECT_FALSE(info_float->flags & kTVMFFIFieldFlagBitMaskWritable); EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value field"); const TVMFFIFieldInfo* info_prim_expr_dtype = reflection::GetFieldInfo("test.PrimExpr", "dtype"); @@ -97,7 +99,7 @@ TEST(Reflection, FieldInfo) { EXPECT_EQ(default_value.cast(), "float"); EXPECT_EQ(default_value.as().value().use_count(), 2); EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskHasDefault); - EXPECT_FALSE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable); + EXPECT_TRUE(info_prim_expr_dtype->flags & kTVMFFIFieldFlagBitMaskWritable); EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype field"); } diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py index 0a8b223405b9..b507064e34d9 100644 --- a/python/tvm/ffi/__init__.py +++ b/python/tvm/ffi/__init__.py @@ -30,6 +30,7 @@ from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu from .ndarray import from_dlpack, NDArray, Shape from .container import Array, Map +from . import testing __all__ = [ diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index e18d52fc8d84..50831be462ad 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -134,6 +134,52 @@ cdef extern from "tvm/ffi/c_api.h": void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) noexcept + cdef enum TVMFFIFieldFlagBitMask: + kTVMFFIFieldFlagBitMaskWritable = 1 << 0 + kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1 + kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2 + + ctypedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result) noexcept; + ctypedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value) noexcept; + ctypedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result) noexcept; + + ctypedef struct TVMFFIFieldInfo: + TVMFFIByteArray name + TVMFFIByteArray doc + TVMFFIByteArray type_schema + int64_t flags + int64_t size + int64_t alignment + int64_t offset + TVMFFIFieldGetter getter + TVMFFIFieldSetter setter + TVMFFIAny default_value + int32_t field_static_type_index + + ctypedef struct TVMFFIMethodInfo: + TVMFFIByteArray name + TVMFFIByteArray doc + TVMFFIByteArray type_schema + int64_t flags + TVMFFIAny method + + ctypedef struct TVMFFITypeExtraInfo: + TVMFFIByteArray doc + TVMFFIObjectCreator creator + int64_t total_size + + ctypedef struct TVMFFITypeInfo: + int32_t type_index + int32_t type_depth + TVMFFIByteArray type_key + const int32_t* type_acenstors + uint64_t type_key_hash + int32_t num_fields + int32_t num_methods + const TVMFFIFieldInfo* fields + const TVMFFIMethodInfo* methods + const TVMFFITypeExtraInfo* extra_info + int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, @@ -161,6 +207,7 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFINDArrayToDLPack(TVMFFIObjectHandle src, DLManagedTensor** out) nogil int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle src, DLManagedTensorVersioned** out) nogil + const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) nogil TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil diff --git a/python/tvm/ffi/cython/error.pxi b/python/tvm/ffi/cython/error.pxi index 3a19573b8f94..8da630873ede 100644 --- a/python/tvm/ffi/cython/error.pxi +++ b/python/tvm/ffi/cython/error.pxi @@ -113,6 +113,7 @@ cdef class Error(Object): def traceback(self): return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).traceback)) + _register_object_by_index(kTVMFFIError, Error) diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index 294a1246b27b..640fff7af557 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -230,6 +230,101 @@ class Function(Object): _register_object_by_index(kTVMFFIFunction, Function) +cdef class FieldGetter: + cdef TVMFFIFieldGetter getter + cdef int64_t offset + + def __call__(self, Object obj): + cdef TVMFFIAny result + cdef int c_api_ret_code + cdef void* field_ptr = ((obj).chandle) + self.offset + result.type_index = kTVMFFINone + result.v_int64 = 0 + c_api_ret_code = self.getter(field_ptr, &result) + CHECK_CALL(c_api_ret_code) + return make_ret(result) + + +cdef class FieldSetter: + cdef TVMFFIFieldSetter setter + cdef int64_t offset + + def __call__(self, Object obj, value): + cdef TVMFFIAny[1] packed_args + cdef int c_api_ret_code + cdef void* field_ptr = ((obj).chandle) + self.offset + cdef int nargs = 1 + temp_args = [] + make_args((value,), &packed_args[0], temp_args) + c_api_ret_code = self.setter(field_ptr, &packed_args[0]) + # NOTE: logic is same as check_call + # directly inline here to simplify traceback + if c_api_ret_code == 0: + return + elif c_api_ret_code == -2: + raise_existing_error() + raise move_from_last_error().py_error() + + +cdef _get_method_from_method_info(const TVMFFIMethodInfo* method): + cdef TVMFFIAny result + CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result)) + return make_ret(result) + + +def _add_class_attrs_by_reflection(int type_index, object cls): + """Decorate the class attrs by reflection""" + cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index) + cdef const TVMFFIFieldInfo* field + cdef const TVMFFIMethodInfo* method + cdef int num_fields = info.num_fields + cdef int num_methods = info.num_methods + + for i in range(num_fields): + # attach fields to the class + field = &(info.fields[i]) + getter = FieldGetter.__new__(FieldGetter) + (getter).getter = field.getter + (getter).offset = field.offset + setter = FieldSetter.__new__(FieldSetter) + (setter).setter = field.setter + (setter).offset = field.offset + if (field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0: + setter = None + doc = ( + py_str(PyBytes_FromStringAndSize(field.doc.data, field.doc.size)) + if field.doc.size != 0 + else None + ) + name = py_str(PyBytes_FromStringAndSize(field.name.data, field.name.size)) + setattr(cls, name, property(getter, setter, doc=doc)) + + for i in range(num_methods): + # attach methods to the class + method = &(info.methods[i]) + name = py_str(PyBytes_FromStringAndSize(method.name.data, method.name.size)) + doc = ( + py_str(PyBytes_FromStringAndSize(method.doc.data, method.doc.size)) + if method.doc.size != 0 + else None + ) + method_func = _get_method_from_method_info(method) + + if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod: + method_pyfunc = staticmethod(method_func) + else: + def method_pyfunc(self, *args): + return method_func(self, *args) + + if doc is not None: + method_pyfunc.__doc__ = doc + method_pyfunc.__name__ = name + + setattr(cls, name, method_pyfunc) + + return cls + + def _register_global_func(name, pyfunc, override): cdef TVMFFIObjectHandle chandle cdef int c_api_ret_code diff --git a/python/tvm/ffi/cython/ndarray.pxi b/python/tvm/ffi/cython/ndarray.pxi index b8534b41b38b..9dfe1222dc7e 100644 --- a/python/tvm/ffi/cython/ndarray.pxi +++ b/python/tvm/ffi/cython/ndarray.pxi @@ -23,7 +23,6 @@ _CLASS_NDARRAY = None def _set_class_ndarray(cls): global _CLASS_NDARRAY _CLASS_NDARRAY = cls - _register_object_by_index(kTVMFFINDArray, cls) cdef const char* _c_str_dltensor = "dltensor" @@ -268,6 +267,7 @@ cdef class NDArray(Object): _set_class_ndarray(NDArray) +_register_object_by_index(kTVMFFINDArray, NDArray) cdef inline object make_ret_dltensor(TVMFFIAny result): diff --git a/python/tvm/ffi/registry.py b/python/tvm/ffi/registry.py index 58df08d90c56..9302b251733b 100644 --- a/python/tvm/ffi/registry.py +++ b/python/tvm/ffi/registry.py @@ -50,6 +50,7 @@ def register(cls): if _SKIP_UNKNOWN_OBJECTS: return cls raise ValueError("Cannot find object type index for %s" % object_name) + core._add_class_attrs_by_reflection(type_index, cls) core._register_object_by_index(type_index, cls) return cls diff --git a/python/tvm/ffi/testing.py b/python/tvm/ffi/testing.py new file mode 100644 index 000000000000..843a10c896a8 --- /dev/null +++ b/python/tvm/ffi/testing.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities.""" + +from . import _ffi_api +from .core import Object +from .registry import register_object + + +@register_object("testing.TestObjectBase") +class TestObjectBase(Object): + """ + Test object base class. + """ + + +@register_object("testing.TestObjectDerived") +class TestObjectDerived(TestObjectBase): + """ + Test object derived class. + """ + + +def create_object(type_key: str, **kwargs) -> Object: + """ + Make an object by reflection. + + Parameters + ---------- + type_key : str + The type key of the object. + kwargs : dict + The keyword arguments to the object. + + Returns + ------- + obj : object + The created object. + + Note + ---- + This function is only used for testing purposes and should + not be used in other cases. + """ + args = [type_key] + for k, v in kwargs.items(): + args.append(k) + args.append(v) + return _ffi_api.MakeObjectFromPackedArgs(*args) diff --git a/tests/python/ffi/test_object.py b/tests/python/ffi/test_object.py new file mode 100644 index 000000000000..d333cbca089c --- /dev/null +++ b/tests/python/ffi/test_object.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm import ffi as tvm_ffi + + +def test_make_object(): + # with default values + obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase") + assert obj0.v_i64 == 10 + assert obj0.v_f64 == 10.0 + assert obj0.v_str == "hello" + + +def test_method(): + obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12) + assert obj0.add_i64(1) == 13 + assert type(obj0).add_i64.__doc__ == "add_i64 method" + assert type(obj0).v_i64.__doc__ == "i64 field" + + +def test_setter(): + # test setter + obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=10, v_str="hello") + assert obj0.v_i64 == 10 + obj0.v_i64 = 11 + assert obj0.v_i64 == 11 + obj0.v_str = "world" + assert obj0.v_str == "world" + + with pytest.raises(TypeError): + obj0.v_str = 1 + + with pytest.raises(TypeError): + obj0.v_i64 = "hello" + + +def test_derived_object(): + with pytest.raises(TypeError): + obj0 = tvm_ffi.testing.create_object("testing.TestObjectDerived") + + v_map = tvm_ffi.convert({"a": 1}) + v_array = tvm_ffi.convert([1, 2, 3]) + + obj0 = tvm_ffi.testing.create_object( + "testing.TestObjectDerived", v_i64=20, v_map=v_map, v_array=v_array + ) + assert obj0.v_map.same_as(v_map) + assert obj0.v_array.same_as(v_array) + assert obj0.v_i64 == 20 + assert obj0.v_f64 == 10.0 + assert obj0.v_str == "hello" + + obj0.v_i64 = 21 + assert obj0.v_i64 == 21