Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ if (TVM_FFI_BUILD_PYTHON_MODULE)
Python_add_library(tvm_ffi_cython MODULE "${core_cpp}" WITH_SOABI)
set_target_properties(tvm_ffi_cython PROPERTIES OUTPUT_NAME "core")
endif()
target_include_directories(tvm_ffi_cython PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython)
target_compile_features(tvm_ffi_cython PRIVATE cxx_std_17)
target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_header)
target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_shared)
Expand Down
65 changes: 50 additions & 15 deletions ffi/include/tvm/ffi/container/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <tvm/ffi/error.h>
#include <tvm/ffi/type_traits.h>

#include <atomic>
#include <memory>
#include <utility>

namespace tvm {
Expand Down Expand Up @@ -123,18 +125,26 @@ class TensorObj : public Object, public DLTensor {
static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor;
TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object);
/// \endcond

~TensorObj() {
// deleting the cached dl managed tensor versioned
// need to acquire the value in case it is released by another thread
DLManagedTensorVersioned* cached =
cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire);
if (cached != nullptr) {
delete cached;
}
}
/*!
* \brief Move a Tensor to a DLPack managed tensor.
* \return The converted DLPack managed tensor.
*/
DLManagedTensor* ToDLPack() const {
TensorObj* self = const_cast<TensorObj*>(this);
DLManagedTensor* ret = new DLManagedTensor();
TensorObj* from = const_cast<TensorObj*>(this);
ret->dl_tensor = *static_cast<DLTensor*>(from);
ret->manager_ctx = from;
ret->dl_tensor = *static_cast<DLTensor*>(self);
ret->manager_ctx = self;
ret->deleter = DLManagedTensorDeleter;
details::ObjectUnsafe::IncRefObjectHandle(from);
details::ObjectUnsafe::IncRefObjectHandle(self);
return ret;
}

Expand All @@ -143,23 +153,49 @@ class TensorObj : public Object, public DLTensor {
* \return The converted DLPack managed tensor.
*/
DLManagedTensorVersioned* ToDLPackVersioned() const {
DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
TensorObj* from = const_cast<TensorObj*>(this);
ret->version.major = DLPACK_MAJOR_VERSION;
ret->version.minor = DLPACK_MINOR_VERSION;
ret->dl_tensor = *static_cast<DLTensor*>(from);
ret->manager_ctx = from;
ret->deleter = DLManagedTensorVersionedDeleter;
ret->flags = 0;
// if cache is set, directly return it
// we need to use acquire to ensure that a write to the DLManagedTensorVersioned
// from another thread is visible to this thread.
DLManagedTensorVersioned* cached =
cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire);
// if cache is not set, create a new one
if (cached == nullptr) {
DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
ret->version.major = DLPACK_MAJOR_VERSION;
ret->version.minor = DLPACK_MINOR_VERSION;
ret->dl_tensor = *static_cast<DLTensor*>(from);
ret->manager_ctx = from;
ret->deleter = EmbeddedDLManagedTensorVersionedDeleter;
ret->flags = 0;
DLManagedTensorVersioned* expected = nullptr;
// a successful exchange must release the new value to all other threads;
// a failed exchange must acquire, since the expected value now comes
// from another thread that released it
if (std::atomic_compare_exchange_strong_explicit(&cached_dl_managed_tensor_versioned_,
&expected, ret, std::memory_order_release,
std::memory_order_acquire)) {
// the set succeeded
cached = ret;
} else {
// delete the ret value as another thread raced to set this one first
delete ret;
cached = expected;
}
// at this point, cached holds the value that was officially set to the field
}
// inc the ref count of the from object
details::ObjectUnsafe::IncRefObjectHandle(from);
return ret;
return cached;
}

protected:
/*! \brief Internal data to back returning shape. */
Optional<Shape> shape_data_;
/*! \brief Internal data to back returning strides. */
Optional<Shape> strides_data_;
/*! \brief cached data to back returning DLManagedTensorVersioned. */
mutable std::atomic<DLManagedTensorVersioned*> cached_dl_managed_tensor_versioned_ = nullptr;

/*!
* \brief Deleter for DLManagedTensor.
Expand All @@ -175,10 +211,9 @@ class TensorObj : public Object, public DLTensor {
* \brief Deleter for DLManagedTensorVersioned.
* \param tensor The DLManagedTensorVersioned to be deleted.
*/
static void DLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) {
static void EmbeddedDLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) {
TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
details::ObjectUnsafe::DecRefObjectHandle(obj);
delete tensor;
}

friend class Tensor;
Expand Down
55 changes: 54 additions & 1 deletion ffi/python/tvm_ffi/cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ cdef extern from "dlpack/dlpack.h":

ctypedef struct DLManagedTensorVersioned:
DLPackVersion version
DLManagedTensor dl_tensor
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensorVersioned* self)
uint64_t flags
Expand Down Expand Up @@ -195,6 +195,7 @@ cdef extern from "tvm/ffi/c_api.h":
const TVMFFITypeMetadata* metadata

int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil
int TVMFFIObjectIncRef(TVMFFIObjectHandle obj) nogil
int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index,
void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil
Expand Down Expand Up @@ -243,6 +244,58 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
TVMFFIStreamHandle* opt_out_original_stream) nogil


cdef extern from "tvm_ffi_python_helpers.h":
# no need to expose fields of the call context
ctypedef struct TVMFFIPyCallContext:
int device_type
int device_id
TVMFFIStreamHandle stream

# setter data structure
ctypedef int (*DLPackPyObjectCExporter)(
void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
) except -1

ctypedef struct TVMFFIPyArgSetter:
int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1
DLPackPyObjectCExporter dlpack_c_exporter

ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1
# The main call function
int TVMFFIPyFuncCall(
TVMFFIPyArgSetterFactory setter_factory,
void* chandle,
PyObject* py_arg_tuple,
TVMFFIAny* result,
int* c_api_ret_code
) except -1

int TVMFFIPyCallFieldSetter(
TVMFFIPyArgSetterFactory setter_factory,
TVMFFIFieldSetter field_setter,
void* field_ptr,
PyObject* py_arg,
int* c_api_ret_code
) except -1

int TVMFFIPyPyObjectToFFIAny(
TVMFFIPyArgSetterFactory setter_factory,
PyObject* py_arg,
TVMFFIAny* out,
int* c_api_ret_code
) except -1

size_t TVMFFIPyGetDispatchMapSize() noexcept

void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept
void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept
# the predefined setters for common POD types
int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1


cdef class ByteArrayArg:
cdef TVMFFIByteArray cdata
cdef object py_data
Expand Down
Loading
Loading