3030#include < tvm/ffi/error.h>
3131#include < tvm/ffi/type_traits.h>
3232
33+ #include < atomic>
34+ #include < memory>
3335#include < utility>
3436
3537namespace tvm {
@@ -123,18 +125,26 @@ class TensorObj : public Object, public DLTensor {
123125 static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor ;
124126 TVM_FFI_DECLARE_OBJECT_INFO_STATIC (StaticTypeKey::kTVMFFITensor , TensorObj, Object);
125127 // / \endcond
126-
128+ ~TensorObj () {
129+ // deleting the cached dl managed tensor versioned
130+ // need to acquire the value in case it is released by another thread
131+ DLManagedTensorVersioned* cached =
132+ cached_dl_managed_tensor_versioned_.load (std::memory_order_acquire);
133+ if (cached != nullptr ) {
134+ delete cached;
135+ }
136+ }
127137 /* !
128138 * \brief Move a Tensor to a DLPack managed tensor.
129139 * \return The converted DLPack managed tensor.
130140 */
131141 DLManagedTensor* ToDLPack () const {
142+ TensorObj* self = const_cast <TensorObj*>(this );
132143 DLManagedTensor* ret = new DLManagedTensor ();
133- TensorObj* from = const_cast <TensorObj*>(this );
134- ret->dl_tensor = *static_cast <DLTensor*>(from);
135- ret->manager_ctx = from;
144+ ret->dl_tensor = *static_cast <DLTensor*>(self);
145+ ret->manager_ctx = self;
136146 ret->deleter = DLManagedTensorDeleter;
137- details::ObjectUnsafe::IncRefObjectHandle (from );
147+ details::ObjectUnsafe::IncRefObjectHandle (self );
138148 return ret;
139149 }
140150
@@ -143,23 +153,49 @@ class TensorObj : public Object, public DLTensor {
143153 * \return The converted DLPack managed tensor.
144154 */
145155 DLManagedTensorVersioned* ToDLPackVersioned () const {
146- DLManagedTensorVersioned* ret = new DLManagedTensorVersioned ();
147156 TensorObj* from = const_cast <TensorObj*>(this );
148- ret->version .major = DLPACK_MAJOR_VERSION;
149- ret->version .minor = DLPACK_MINOR_VERSION;
150- ret->dl_tensor = *static_cast <DLTensor*>(from);
151- ret->manager_ctx = from;
152- ret->deleter = DLManagedTensorVersionedDeleter;
153- ret->flags = 0 ;
157+ // if cache is set, directly return it
158+ // we need to use acquire to ensure that write to DLManagedTensorVersioned
159+ // from another thread is visible to this thread.
160+ DLManagedTensorVersioned* cached =
161+ cached_dl_managed_tensor_versioned_.load (std::memory_order_acquire);
162+ // if cache is not set, create a new one
163+ if (cached == nullptr ) {
164+ DLManagedTensorVersioned* ret = new DLManagedTensorVersioned ();
165+ ret->version .major = DLPACK_MAJOR_VERSION;
166+ ret->version .minor = DLPACK_MINOR_VERSION;
167+ ret->dl_tensor = *static_cast <DLTensor*>(from);
168+ ret->manager_ctx = from;
169+ ret->deleter = EmbeddedDLManagedTensorVersionedDeleter;
170+ ret->flags = 0 ;
171+ DLManagedTensorVersioned* expected = nullptr ;
172+ // success set must release the new value to all other threads
173+ // failure set must acquire, since the expected value is now coming
174+ // from another thread that released this value
175+ if (std::atomic_compare_exchange_strong_explicit (&cached_dl_managed_tensor_versioned_,
176+ &expected, ret, std::memory_order_release,
177+ std::memory_order_acquire)) {
178+ // set is succes
179+ cached = ret;
180+ } else {
181+ // delete the ret value as another thread raced to set this one first
182+ delete ret;
183+ cached = expected;
184+ }
185+ // at this point, cached is the value that officially set to the field
186+ }
187+ // inc the ref count of the from object
154188 details::ObjectUnsafe::IncRefObjectHandle (from);
155- return ret ;
189+ return cached ;
156190 }
157191
158192 protected:
159193 /* ! \brief Internal data to back returning shape. */
160194 Optional<Shape> shape_data_;
161195 /* ! \brief Internal data to back returning strides. */
162196 Optional<Shape> strides_data_;
197+ /* ! \brief cached data to back returning DLManagedTensorVersioned. */
198+ mutable std::atomic<DLManagedTensorVersioned*> cached_dl_managed_tensor_versioned_ = nullptr ;
163199
164200 /* !
165201 * \brief Deleter for DLManagedTensor.
@@ -175,10 +211,9 @@ class TensorObj : public Object, public DLTensor {
175211 * \brief Deleter for DLManagedTensorVersioned.
176212 * \param tensor The DLManagedTensorVersioned to be deleted.
177213 */
178- static void DLManagedTensorVersionedDeleter (DLManagedTensorVersioned* tensor) {
214+ static void EmbeddedDLManagedTensorVersionedDeleter (DLManagedTensorVersioned* tensor) {
179215 TensorObj* obj = static_cast <TensorObj*>(tensor->manager_ctx );
180216 details::ObjectUnsafe::DecRefObjectHandle (obj);
181- delete tensor;
182217 }
183218
184219 friend class Tensor ;
0 commit comments