diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index f099898b158d..b4f59526a900 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -156,6 +156,36 @@ typedef enum { /*! \brief Handle to Object from C API's pov */ typedef void* TVMFFIObjectHandle; +/*! + * \brief bitmask of the object deleter flag. + */ +#ifdef __cplusplus +enum TVMFFIObjectDeleterFlagBitMask : int32_t { +#else +typedef enum { +#endif + /*! + * \brief deleter action when strong reference count becomes zero. + * Need to call destructor of the object but not free the memory block. + */ + kTVMFFIObjectDeleterFlagBitMaskStrong = 1 << 0, + /*! + * \brief deleter action when weak reference count becomes zero. + * Need to free the memory block. + */ + kTVMFFIObjectDeleterFlagBitMaskWeak = 1 << 1, + /*! + * \brief deleter action when both strong and weak reference counts become zero. + * \note This is the most common case. + */ + kTVMFFIObjectDeleterFlagBitMaskBoth = + (kTVMFFIObjectDeleterFlagBitMaskStrong | kTVMFFIObjectDeleterFlagBitMaskWeak), +#ifdef __cplusplus +}; +#else +} TVMFFIObjectDeleterFlagBitMask; +#endif + /*! * \brief C-based type of all FFI object header that allocates on heap. * \note TVMFFIObject and TVMFFIAny share the common type_index header @@ -166,11 +196,22 @@ typedef struct TVMFFIObject { * \note The type index of Object and Any are shared in FFI. */ int32_t type_index; - /*! \brief Reference counter of the object. */ - int32_t ref_counter; + /*! + * \brief Weak reference counter of the object, for compatiblity with weak_ptr design. + * \note Use u32 to ensure that overall object stays within 24-byte boundary, usually + * manipulation of weak counter is less common than strong counter. + */ + uint32_t weak_ref_count; + /*! \brief Strong reference counter of the object. */ + uint64_t strong_ref_count; union { - /*! \brief Deleter to be invoked when reference counter goes to zero. */ - void (*deleter)(struct TVMFFIObject* self); + /*! + * \brief Deleter to be invoked when strong reference counter goes to zero. + * \param self The self object handle. + * \param flags The flags to indicate deletion behavior. + * \sa TVMFFIObjectDeleterFlagBitMask + */ + void (*deleter)(struct TVMFFIObject* self, int flags); /*! * \brief auxilary field to TVMFFIObject is always 8 bytes aligned. * \note This helps us to ensure cross platform compatibility. @@ -307,13 +348,19 @@ typedef struct { // Section: Basic object API //------------------------------------------------------------ /*! - * \brief Free an object handle by decreasing reference + * \brief Increas the strong reference count of an object handle + * \param obj The object handle. + * \note Internally we increase the reference counter of the object. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIObjectIncRef(TVMFFIObjectHandle obj); + +/*! + * \brief Free an object handle by decreasing strong reference * \param obj The object handle. - * \note Internally we decrease the reference counter of the object. - * The object will be freed when every reference to the object are removed. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj); +TVM_FFI_DLL int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); /*! * \brief Convert type key to type index. @@ -470,7 +517,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* * \param dtype The DLDataType to convert. * \param out The output string. * \return 0 when success, nonzero when failure happens -* \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree. +* \note out is a String object that needs to be freed by the caller via TVMFFIObjectDecRef. The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. * \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index 02537df79cb4..533d0004274f 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -33,7 +33,7 @@ namespace tvm { namespace ffi { /*! \brief Deleter function for obeject */ -typedef void (*FObjectDeleter)(TVMFFIObject* obj); +typedef void (*FObjectDeleter)(TVMFFIObject* obj, int flags); /*! * \brief Allocate an object using default allocator. @@ -75,7 +75,8 @@ class ObjAllocatorBase { static_assert(std::is_base_of::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->ref_counter = 1; + ffi_ptr->strong_ref_count = 1; + ffi_ptr->weak_ref_count = 1; ffi_ptr->type_index = T::RuntimeTypeIndex(); ffi_ptr->deleter = Handler::Deleter(); return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); @@ -96,7 +97,8 @@ class ObjAllocatorBase { ArrayType* ptr = Handler::New(static_cast(this), num_elems, std::forward(args)...); TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->ref_counter = 1; + ffi_ptr->strong_ref_count = 1; + ffi_ptr->weak_ref_count = 1; ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); ffi_ptr->deleter = Handler::Deleter(); return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); @@ -136,14 +138,18 @@ class SimpleObjAllocator : public ObjAllocatorBase { static FObjectDeleter Deleter() { return Deleter_; } private: - static void Deleter_(TVMFFIObject* objptr) { + static void Deleter_(TVMFFIObject* objptr, int flags) { T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(objptr); - // It is important to do tptr->T::~T(), - // so that we explicitly call the specific destructor - // instead of tptr->~T(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->T::~T(); - delete reinterpret_cast(tptr); + if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { + // It is important to do tptr->T::~T(), + // so that we explicitly call the specific destructor + // instead of tptr->~T(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->T::~T(); + } + if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { + delete reinterpret_cast(tptr); + } } }; @@ -182,15 +188,19 @@ class SimpleObjAllocator : public ObjAllocatorBase { static FObjectDeleter Deleter() { return Deleter_; } private: - static void Deleter_(TVMFFIObject* objptr) { + static void Deleter_(TVMFFIObject* objptr, int flags) { ArrayType* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(objptr); - // It is important to do tptr->ArrayType::~ArrayType(), - // so that we explicitly call the specific destructor - // instead of tptr->~ArrayType(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->ArrayType::~ArrayType(); - StorageType* p = reinterpret_cast(tptr); - delete[] p; + if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { + // It is important to do tptr->ArrayType::~ArrayType(), + // so that we explicitly call the specific destructor + // instead of tptr->~ArrayType(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->ArrayType::~ArrayType(); + } + if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { + StorageType* p = reinterpret_cast(tptr); + delete[] p; + } } }; }; diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index cf282a6e2744..cc5ee8d94585 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -143,7 +143,8 @@ class Object { public: Object() { - header_.ref_counter = 0; + header_.strong_ref_count = 0; + header_.weak_ref_count = 0; header_.deleter = nullptr; } /*! @@ -197,9 +198,9 @@ class Object { int32_t use_count() const { // only need relaxed load of counters #ifdef _MSC_VER - return (reinterpret_cast(&header_.ref_counter))[0]; // NOLINT(*) + return (reinterpret_cast(&header_.strong_ref_count))[0]; // NOLINT(*) #else - return __atomic_load_n(&(header_.ref_counter), __ATOMIC_RELAXED); + return __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED); #endif } @@ -230,33 +231,121 @@ class Object { static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } private: - /*! \brief increase reference count */ + /*! \brief increase strong reference count, the caller must already hold a strong reference */ void IncRef() { #ifdef _MSC_VER - _InterlockedIncrement(reinterpret_cast(&header_.ref_counter)); // NOLINT(*) + _InterlockedIncrement64( + reinterpret_cast(&header_.strong_ref_count)); // NOLINT(*) #else - __atomic_fetch_add(&(header_.ref_counter), 1, __ATOMIC_RELAXED); + __atomic_fetch_add(&(header_.strong_ref_count), 1, __ATOMIC_RELAXED); +#endif + } + /*! + * \brief Try to lock the object to increase the strong reference count, + * the caller must already hold a strong reference. + * \return whether the lock call is successful and object is still alive. + */ + bool TryPromoteWeakPtr() { +#ifdef _MSC_VER + uint64_t old_count = + (reinterpret_cast(&header_.strong_ref_count))[0]; // NOLINT(*) + while (old_count > 0) { + uint64_t new_count = old_count + 1; + uint64_t old_count_loaded = _InterlockedCompareExchange64( + reinterpret_cast(&header_.strong_ref_count), new_count, old_count); + if (old_count == old_count_loaded) { + return true; + } + old_count = old_count_loaded; + } + return false; +#else + uint64_t old_count = __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED); + while (old_count > 0) { + // must do CAS to ensure that we are the only one that increases the reference count + // avoid condition when two threads tries to promote weak to strong at same time + // or when strong deletion happens between the load and the CAS + uint64_t new_count = old_count + 1; + if (__atomic_compare_exchange_n(&(header_.strong_ref_count), &old_count, new_count, true, + __ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) { + return true; + } + } + return false; +#endif + } + + /*! \brief increase weak reference count */ + void IncWeakRef() { +#ifdef _MSC_VER + _InterlockedIncrement(reinterpret_cast(&header_.weak_ref_count)); // NOLINT(*) +#else + __atomic_fetch_add(&(header_.weak_ref_count), 1, __ATOMIC_RELAXED); #endif } - /*! \brief decrease reference count and delete the object */ + /*! \brief decrease strong reference count and delete the object */ void DecRef() { #ifdef _MSC_VER - if (_InterlockedDecrement( // - reinterpret_cast(&header_.ref_counter)) == 0) { // NOLINT(*) + // use simpler impl in windows to ensure correctness + if (_InterlockedDecrement64( // + reinterpret_cast(&header_.strong_ref_count)) == 0) { // NOLINT(*) // full barrrier is implicit in InterlockedDecrement if (header_.deleter != nullptr) { - header_.deleter(&(this->header_)); + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); + } + if (_InterlockedDecrement( // + reinterpret_cast(&header_.weak_ref_count)) == 0) { // NOLINT(*) + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } } } #else // first do a release, note we only need to acquire for deleter - if (__atomic_fetch_sub(&(header_.ref_counter), 1, __ATOMIC_RELEASE) == 1) { - // only acquire when we need to call deleter - // in this case we need to ensure all previous writes are visible + if (__atomic_fetch_sub(&(header_.strong_ref_count), 1, __ATOMIC_RELEASE) == 1) { + if (__atomic_load_n(&(header_.weak_ref_count), __ATOMIC_RELAXED) == 1) { + // common case, we need to delete both the object and the memory block + // only acquire when we need to call deleter + __atomic_thread_fence(__ATOMIC_ACQUIRE); + if (header_.deleter != nullptr) { + // call deleter once + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth); + } + } else { + // Slower path: there is still a weak reference left + __atomic_thread_fence(__ATOMIC_ACQUIRE); + // call destructor first, then decrease weak reference count + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); + } + // now decrease weak reference count + if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 1) { + __atomic_thread_fence(__ATOMIC_ACQUIRE); + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } + } + } + } +#endif + } + + /*! \brief decrease weak reference count */ + void DecWeakRef() { +#ifdef _MSC_VER + if (_InterlockedDecrement( // + reinterpret_cast(&header_.weak_ref_count)) == 0) { // NOLINT(*) + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } + } +#else + // now decrease weak reference count + if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 1) { __atomic_thread_fence(__ATOMIC_ACQUIRE); if (header_.deleter != nullptr) { - header_.deleter(&(this->header_)); + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); } } #endif @@ -265,6 +354,8 @@ class Object { // friend classes template friend class ObjectPtr; + template + friend class WeakObjectPtr; friend struct tvm::ffi::details::ObjectUnsafe; }; @@ -402,6 +493,148 @@ class ObjectPtr { friend struct ObjectPtrHash; template friend class ObjectPtr; + template + friend class WeakObjectPtr; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +/*! + * \brief A custom smart pointer for Object. + * \tparam T the content data type. + * \sa make_object + */ +template +class WeakObjectPtr { + public: + /*! \brief default constructor */ + WeakObjectPtr() {} + /*! \brief default constructor */ + WeakObjectPtr(std::nullptr_t) {} // NOLINT(*) + /*! + * \brief copy constructor + * \param other The value to be moved + */ + WeakObjectPtr(const WeakObjectPtr& other) // NOLINT(*) + : WeakObjectPtr(other.data_) {} + + /*! + * \brief copy constructor + * \param other The value to be moved + */ + WeakObjectPtr(const ObjectPtr& other) // NOLINT(*) + : WeakObjectPtr(other.get()) {} + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(const WeakObjectPtr& other) // NOLINT(*) + : WeakObjectPtr(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(const ObjectPtr& other) // NOLINT(*) + : WeakObjectPtr(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + WeakObjectPtr(WeakObjectPtr&& other) // NOLINT(*) + : data_(other.data_) { + other.data_ = nullptr; + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(WeakObjectPtr&& other) // NOLINT(*) + : data_(other.data_) { + static_assert(std::is_base_of::value, + "can only assign of child class ObjectPtr to parent"); + other.data_ = nullptr; + } + /*! \brief destructor */ + ~WeakObjectPtr() { this->reset(); } + /*! + * \brief Swap this array with another Object + * \param other The other Object + */ + void swap(WeakObjectPtr& other) { // NOLINT(*) + std::swap(data_, other.data_); + } + + /*! + * \brief copy assignment + * \param other The value to be assigned. + * \return reference to self. + */ + WeakObjectPtr& operator=(const WeakObjectPtr& other) { // NOLINT(*) + // takes in plane operator to enable copy elison. + // copy-and-swap idiom + WeakObjectPtr(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief move assignment + * \param other The value to be assigned. + * \return reference to self. + */ + WeakObjectPtr& operator=(WeakObjectPtr&& other) { // NOLINT(*) + // copy-and-swap idiom + WeakObjectPtr(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + /*! \return The internal object pointer if the object is still alive, otherwise nullptr */ + ObjectPtr lock() const { + if (data_ != nullptr && data_->TryPromoteWeakPtr()) { + ObjectPtr ret; + // we already increase the reference count, so we don't need to do it again + ret.data_ = data_; + return ret; + } + return nullptr; + } + + /*! \brief reset the content of ptr to be nullptr */ + void reset() { + if (data_ != nullptr) { + data_->DecWeakRef(); + data_ = nullptr; + } + } + + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } + + /*! \return whether the pointer is nullptr */ + bool expired() const { return data_ == nullptr || data_->use_count() == 0; } + + private: + /*! \brief internal pointer field */ + Object* data_{nullptr}; + + /*! + * \brief constructor from Object + * \param data The data pointer + */ + explicit WeakObjectPtr(Object* data) : data_(data) { + if (data_ != nullptr) { + data_->IncWeakRef(); + } + } + + template + friend class WeakObjectPtr; friend struct tvm::ffi::details::ObjectUnsafe; }; diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index b019935a6cc8..9cdb2b933894 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -472,7 +472,7 @@ struct TypeTraits : public TypeTraitsBase { } else if (src->type_index == TypeIndex::kTVMFFINDArray) { // Conversion from NDArray pointer to DLTensor // based on the assumption that NDArray always follows the TVMFFIObject header - static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 8 bytes"); + static_assert(sizeof(TVMFFIObject) == 24); return reinterpret_cast(reinterpret_cast(src->v_obj) + sizeof(TVMFFIObject)); } diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml index 8ed9e275e2b3..083a60fc3631 100644 --- a/ffi/pyproject.toml +++ b/ffi/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "apache-tvm-ffi" -version = "0.1.0a5" +version = "0.1.0a6" description = "tvm ffi" authors = [{ name = "TVM FFI team" }] diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi index 14b3d97f5260..4a47efd773d9 100644 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ b/ffi/python/tvm_ffi/cython/base.pxi @@ -171,7 +171,7 @@ cdef extern from "tvm/ffi/c_api.h": const TVMFFIMethodInfo* methods const TVMFFITypeMetadata* metadata - int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil + int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) nogil diff --git a/ffi/python/tvm_ffi/cython/dtype.pxi b/ffi/python/tvm_ffi/cython/dtype.pxi index 279b17f8c83c..d9e20b77f3a8 100644 --- a/ffi/python/tvm_ffi/cython/dtype.pxi +++ b/ffi/python/tvm_ffi/cython/dtype.pxi @@ -104,7 +104,7 @@ cdef class DataType: bytes_ptr = TVMFFIBytesGetByteArrayPtr(temp_any.v_obj) res = py_str(PyBytes_FromStringAndSize(bytes_ptr.data, bytes_ptr.size)) - CHECK_CALL(TVMFFIObjectFree(temp_any.v_obj)) + CHECK_CALL(TVMFFIObjectDecRef(temp_any.v_obj)) return res diff --git a/ffi/python/tvm_ffi/cython/object.pxi b/ffi/python/tvm_ffi/cython/object.pxi index dad6bee51b34..1203f0c68289 100644 --- a/ffi/python/tvm_ffi/cython/object.pxi +++ b/ffi/python/tvm_ffi/cython/object.pxi @@ -78,7 +78,7 @@ cdef class Object: def __dealloc__(self): if self.chandle != NULL: - CHECK_CALL(TVMFFIObjectFree(self.chandle)) + CHECK_CALL(TVMFFIObjectDecRef(self.chandle)) self.chandle = NULL def __ctypes_handle__(self): diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 61107cb63ff7..f96636fd4994 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -388,12 +388,18 @@ class TypeTable { } // namespace ffi } // namespace tvm -int TVMFFIObjectFree(TVMFFIObjectHandle handle) { +int TVMFFIObjectDecRef(TVMFFIObjectHandle handle) { TVM_FFI_SAFE_CALL_BEGIN(); tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(handle); TVM_FFI_SAFE_CALL_END(); } +int TVMFFIObjectIncRef(TVMFFIObjectHandle handle) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(handle); + TVM_FFI_SAFE_CALL_END(); +} + int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) { TVM_FFI_SAFE_CALL_BEGIN(); out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key); diff --git a/ffi/tests/cpp/test_c_ffi_abi.cc b/ffi/tests/cpp/test_c_ffi_abi.cc index 1efceef2971a..e6c6116edd8c 100644 --- a/ffi/tests/cpp/test_c_ffi_abi.cc +++ b/ffi/tests/cpp/test_c_ffi_abi.cc @@ -25,7 +25,7 @@ TEST(ABIHeaderAlignment, Default) { TVMFFIObject value; value.type_index = 10; EXPECT_EQ(reinterpret_cast(&value)->type_index, 10); - static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 16 bytes"); + static_assert(sizeof(TVMFFIObject) == 24); } } // namespace diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc index 4b53a70b42a2..f6bedcb6f371 100644 --- a/ffi/tests/cpp/test_object.cc +++ b/ffi/tests/cpp/test_object.cc @@ -103,4 +103,123 @@ TEST(Object, CAPIAccessor) { int32_t type_index = TVMFFIObjectGetTypeIndex(obj); EXPECT_EQ(type_index, TIntObj::RuntimeTypeIndex()); } + +TEST(Object, WeakObjectPtr) { + // Test basic construction from ObjectPtr + ObjectPtr strong_ptr = make_object(42); + WeakObjectPtr weak_ptr(strong_ptr); + + EXPECT_EQ(strong_ptr.use_count(), 1); + EXPECT_FALSE(weak_ptr.expired()); + EXPECT_EQ(weak_ptr.use_count(), 1); + + // Test lock() when object is still alive + ObjectPtr locked_ptr = weak_ptr.lock(); + EXPECT_TRUE(locked_ptr != nullptr); + EXPECT_EQ(locked_ptr->value, 42); + EXPECT_EQ(strong_ptr.use_count(), 2); + EXPECT_EQ(weak_ptr.use_count(), 2); + + // Test lock() when object is expired + strong_ptr.reset(); + locked_ptr.reset(); + EXPECT_TRUE(weak_ptr.expired()); + EXPECT_EQ(weak_ptr.use_count(), 0); + + ObjectPtr expired_lock = weak_ptr.lock(); + EXPECT_TRUE(expired_lock == nullptr); +} + +TEST(Object, WeakObjectPtrAssignment) { + // Test copy construction + ObjectPtr new_strong = make_object(100); + WeakObjectPtr weak1(new_strong); + WeakObjectPtr weak2(weak1); + + EXPECT_EQ(new_strong.use_count(), 1); + EXPECT_FALSE(weak1.expired()); + EXPECT_FALSE(weak2.expired()); + EXPECT_EQ(weak1.use_count(), 1); + EXPECT_EQ(weak2.use_count(), 1); + + // Test move construction + WeakObjectPtr weak3(std::move(weak1)); + EXPECT_TRUE(weak1.expired()); // weak1 should be moved from + EXPECT_FALSE(weak3.expired()); + EXPECT_EQ(weak3.use_count(), 1); + + // Test assignment + WeakObjectPtr weak4; + weak4 = weak2; + EXPECT_FALSE(weak2.expired()); + EXPECT_FALSE(weak4.expired()); + EXPECT_EQ(weak2.use_count(), 1); + EXPECT_EQ(weak4.use_count(), 1); + + // Test move assignment + WeakObjectPtr weak5; + weak5 = std::move(weak2); + EXPECT_TRUE(weak2.expired()); // weak2 should be moved from + EXPECT_FALSE(weak5.expired()); + EXPECT_EQ(weak5.use_count(), 1); + + // Test reset() + weak3.reset(); + EXPECT_TRUE(weak3.expired()); + EXPECT_EQ(weak3.use_count(), 0); + + // Test swap() + ObjectPtr strong_a = make_object(200); + ObjectPtr strong_b = make_object(300); + WeakObjectPtr weak_a(strong_a); + WeakObjectPtr weak_b(strong_b); + + weak_a.swap(weak_b); + EXPECT_EQ(weak_a.lock()->value, 300); + EXPECT_EQ(weak_b.lock()->value, 200); + + // Test construction from nullptr + WeakObjectPtr null_weak(nullptr); + EXPECT_TRUE(null_weak.expired()); + EXPECT_EQ(null_weak.use_count(), 0); + EXPECT_TRUE(null_weak.lock() == nullptr); + + // Test inheritance compatibility + ObjectPtr number_ptr = make_object(500); + WeakObjectPtr number_weak(number_ptr); + + EXPECT_FALSE(number_weak.expired()); + EXPECT_EQ(number_weak.use_count(), 1); + + // Test that weak references don't prevent object deletion + ObjectPtr temp_strong = make_object(999); + WeakObjectPtr temp_weak(temp_strong); + + EXPECT_FALSE(temp_weak.expired()); + temp_strong.reset(); + EXPECT_TRUE(temp_weak.expired()); + EXPECT_TRUE(temp_weak.lock() == nullptr); + + // Test multiple weak references + ObjectPtr multi_strong = make_object(777); + WeakObjectPtr multi_weak1(multi_strong); + WeakObjectPtr multi_weak2(multi_strong); + WeakObjectPtr multi_weak3(multi_strong); + + EXPECT_EQ(multi_strong.use_count(), 1); + EXPECT_FALSE(multi_weak1.expired()); + EXPECT_FALSE(multi_weak2.expired()); + EXPECT_FALSE(multi_weak3.expired()); + + // All weak references should be able to lock + ObjectPtr lock1 = multi_weak1.lock(); + ObjectPtr lock2 = multi_weak2.lock(); + ObjectPtr lock3 = multi_weak3.lock(); + + EXPECT_EQ(multi_strong.use_count(), 4); + EXPECT_EQ(lock1->value, 777); + EXPECT_EQ(lock2->value, 777); + EXPECT_EQ(lock3->value, 777); +} + } // namespace diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index 5db3e279cf3f..9b50fb6a4914 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -236,7 +236,7 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) { } case TypeIndex::kTVMFFIBytes: { jobject ret = newTVMValueBytes(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); - TVMFFIObjectFree(value.v_obj); + TVMFFIObjectDecRef(value.v_obj); return ret; } default: { diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index 3ebe7fddfa8f..b512ec8775bd 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -322,7 +322,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionSetGlobal(JNIEn // Module JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIObjectFree(JNIEnv* env, jobject obj, jlong jhandle) { - return TVMFFIObjectFree(reinterpret_cast(jhandle)); + return TVMFFIObjectDecRef(reinterpret_cast(jhandle)); } // NDArray diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 7477fe86363d..e6c6e9aa0275 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -299,10 +299,12 @@ PrimFunc MakePackedAPI(PrimFunc func) { tvm::tir::StringImm(msg.str()), nop)); // if type_index is NDArray, we need to add the offset of the DLTensor header // which always equals 16 bytes, this ensures that T.handle always shows up as a DLTensor* + const int64_t object_cell_offset = sizeof(TVMFFIObject); + static_assert(object_cell_offset == 24); arg_value = f_load_arg_value(param.dtype(), i); PrimExpr handle_from_ndarray = Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), - {arg_value, IntImm(DataType::Int(32), 16)}); + {arg_value, IntImm(DataType::Int(32), object_cell_offset)}); arg_value = Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value); } else if (dtype.is_bool()) { diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index d2ecf4b944b0..9836fbfda530 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -41,7 +41,7 @@ export const enum SizeOf { TVMFFIAny = 8 * 2, DLDataType = I32, DLDevice = I32 + I32, - ObjectHeader = 8 * 2, + ObjectHeader = 8 * 3, } //---------------The new TVM FFI--------------- @@ -142,9 +142,9 @@ export type FTVMFFIWasmFunctionCreate = ( export type FTVMFFIWasmFunctionDeleter = (self: Pointer) => void; /** - * int TVMFFIObjectFree(TVMFFIObjectHandle obj); + * int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); */ -export type FTVMFFIObjectFree = (obj: Pointer) => number; +export type FTVMFFIObjectDecRef = (obj: Pointer) => number; /** * int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 071b2eed68e4..3720b1873eee 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -450,7 +450,7 @@ export class TVMObject implements Disposable { dispose(): void { if (this.handle != 0) { this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(this.handle) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(this.handle) ); this.handle = 0; } @@ -2253,7 +2253,7 @@ export class Instance implements Disposable { const strObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(strObjPtr) ); return result; } @@ -2264,7 +2264,7 @@ export class Instance implements Disposable { const strObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(strObjPtr) ); return result; } @@ -2275,7 +2275,7 @@ export class Instance implements Disposable { const bytesObjPtr = this.memory.loadPointer(valuePtr); const result = this.memory.loadByteArrayAsBytes(bytesObjPtr + SizeOf.ObjectHeader); this.lib.checkCall( - (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(bytesObjPtr) + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(bytesObjPtr) ); return result; }