
Commit 8f658cc

[FFI][REFACTOR] Refactor python ffi call mechanism for perf (#18302)
This PR refactors the Python FFI call mechanism. Previously, argument setting could become error-prone because the conversion logic was sensitive to the order of the if checks. This PR refactors the calling path to leverage a C++-based dispatcher in which each dispatch functor can be registered from Cython, as sketched below.
1 parent da7b68d commit 8f658cc
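As a rough sketch of the dispatcher idea described above (illustrative only: FFIAnyStub, RegisterSetter, SetArg, and SetterInt are made-up stand-ins, not tvm-ffi symbols), each Python type gets one conversion functor registered in a type-keyed table, and the call path looks the functor up by type instead of walking an order-sensitive if/elif chain:

#include <cstdio>
#include <unordered_map>

struct FFIAnyStub { long long v_int64; };                  // stand-in for TVMFFIAny
using ArgSetter = int (*)(void* py_arg, FFIAnyStub* out);  // stand-in setter signature

// type-keyed dispatch table; in the real code the key would be the Python type
static std::unordered_map<const void*, ArgSetter>& DispatchMap() {
  static std::unordered_map<const void*, ArgSetter> map;
  return map;
}

// Cython registers one setter per Python type, once
void RegisterSetter(const void* py_type, ArgSetter setter) { DispatchMap()[py_type] = setter; }

// the call path looks the setter up by type
int SetArg(const void* py_type, void* py_arg, FFIAnyStub* out) {
  auto it = DispatchMap().find(py_type);
  if (it == DispatchMap().end()) return -1;  // the real code would fall back to a slow path
  return it->second(py_arg, out);
}

// toy integer setter standing in for a registered dispatch functor
int SetterInt(void* py_arg, FFIAnyStub* out) {
  out->v_int64 = *static_cast<long long*>(py_arg);
  return 0;
}

int main() {
  static const int kIntTypeTag = 0;  // stands in for a PyTypeObject*
  RegisterSetter(&kIntTypeTag, SetterInt);
  long long x = 42;
  FFIAnyStub packed{};
  if (SetArg(&kIntTypeTag, &x, &packed) == 0) std::printf("%lld\n", packed.v_int64);
  return 0;
}

In the actual change the dispatch table appears to live on the C++ side (note TVMFFIPyGetDispatchMapSize in the base.pxi diff below), with the per-type setters supplied from Cython through a TVMFFIPyArgSetterFactory.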

File tree

8 files changed, +986 −238 lines changed


ffi/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@ if (TVM_FFI_BUILD_PYTHON_MODULE)
 Python_add_library(tvm_ffi_cython MODULE "${core_cpp}" WITH_SOABI)
 set_target_properties(tvm_ffi_cython PROPERTIES OUTPUT_NAME "core")
 endif()
+target_include_directories(tvm_ffi_cython PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython)
 target_compile_features(tvm_ffi_cython PRIVATE cxx_std_17)
 target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_header)
 target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_shared)

ffi/include/tvm/ffi/container/tensor.h

Lines changed: 50 additions & 15 deletions
@@ -30,6 +30,8 @@
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/type_traits.h>
 
+#include <atomic>
+#include <memory>
 #include <utility>
 
 namespace tvm {
@@ -123,18 +125,26 @@ class TensorObj : public Object, public DLTensor {
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor;
   TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object);
   /// \endcond
-
+  ~TensorObj() {
+    // delete the cached DLManagedTensorVersioned;
+    // we need to acquire the value in case it was released by another thread
+    DLManagedTensorVersioned* cached =
+        cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire);
+    if (cached != nullptr) {
+      delete cached;
+    }
+  }
   /*!
    * \brief Move a Tensor to a DLPack managed tensor.
    * \return The converted DLPack managed tensor.
    */
   DLManagedTensor* ToDLPack() const {
+    TensorObj* self = const_cast<TensorObj*>(this);
     DLManagedTensor* ret = new DLManagedTensor();
-    TensorObj* from = const_cast<TensorObj*>(this);
-    ret->dl_tensor = *static_cast<DLTensor*>(from);
-    ret->manager_ctx = from;
+    ret->dl_tensor = *static_cast<DLTensor*>(self);
+    ret->manager_ctx = self;
     ret->deleter = DLManagedTensorDeleter;
-    details::ObjectUnsafe::IncRefObjectHandle(from);
+    details::ObjectUnsafe::IncRefObjectHandle(self);
     return ret;
   }
 
@@ -143,23 +153,49 @@ class TensorObj : public Object, public DLTensor {
    * \return The converted DLPack managed tensor.
    */
   DLManagedTensorVersioned* ToDLPackVersioned() const {
-    DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
     TensorObj* from = const_cast<TensorObj*>(this);
-    ret->version.major = DLPACK_MAJOR_VERSION;
-    ret->version.minor = DLPACK_MINOR_VERSION;
-    ret->dl_tensor = *static_cast<DLTensor*>(from);
-    ret->manager_ctx = from;
-    ret->deleter = DLManagedTensorVersionedDeleter;
-    ret->flags = 0;
+    // if the cache is set, return it directly;
+    // we need to use acquire so that the write to the DLManagedTensorVersioned
+    // from another thread is visible to this thread
+    DLManagedTensorVersioned* cached =
+        cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire);
+    // if the cache is not set, create a new one
+    if (cached == nullptr) {
+      DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
+      ret->version.major = DLPACK_MAJOR_VERSION;
+      ret->version.minor = DLPACK_MINOR_VERSION;
+      ret->dl_tensor = *static_cast<DLTensor*>(from);
+      ret->manager_ctx = from;
+      ret->deleter = EmbeddedDLManagedTensorVersionedDeleter;
+      ret->flags = 0;
+      DLManagedTensorVersioned* expected = nullptr;
+      // a successful set must release the new value to all other threads;
+      // a failed set must acquire, since the expected value now comes
+      // from another thread that released it
+      if (std::atomic_compare_exchange_strong_explicit(&cached_dl_managed_tensor_versioned_,
+                                                       &expected, ret, std::memory_order_release,
+                                                       std::memory_order_acquire)) {
+        // the set succeeded
+        cached = ret;
+      } else {
+        // delete ret, as another thread raced to set the cache first
+        delete ret;
+        cached = expected;
+      }
+      // at this point, cached is the value that was officially set to the field
+    }
+    // inc the ref count of the from object
     details::ObjectUnsafe::IncRefObjectHandle(from);
-    return ret;
+    return cached;
  }
 
  protected:
   /*! \brief Internal data to back returning shape. */
   Optional<Shape> shape_data_;
   /*! \brief Internal data to back returning strides. */
   Optional<Shape> strides_data_;
+  /*! \brief Cached data to back returning DLManagedTensorVersioned. */
+  mutable std::atomic<DLManagedTensorVersioned*> cached_dl_managed_tensor_versioned_ = nullptr;
 
   /*!
    * \brief Deleter for DLManagedTensor.
@@ -175,10 +211,9 @@ class TensorObj : public Object, public DLTensor {
    * \brief Deleter for DLManagedTensorVersioned.
    * \param tensor The DLManagedTensorVersioned to be deleted.
    */
-  static void DLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) {
+  static void EmbeddedDLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) {
     TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
     details::ObjectUnsafe::DecRefObjectHandle(obj);
-    delete tensor;
   }
 
   friend class Tensor;
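The ToDLPackVersioned change above is a publish-once cache: load the atomic pointer with acquire, allocate on a miss, and try to install the new pointer with a compare-exchange so that exactly one thread's allocation wins and the losers delete their copy. A minimal self-contained sketch of the same memory-ordering pattern (Payload and Cache are illustrative stand-ins, not tvm-ffi types):

#include <atomic>
#include <cstdio>

struct Payload { int value; };  // stand-in for DLManagedTensorVersioned

class Cache {
 public:
  Payload* GetOrCreate() const {
    // acquire: make another thread's fully constructed Payload visible here
    Payload* cached = cached_.load(std::memory_order_acquire);
    if (cached == nullptr) {
      Payload* fresh = new Payload{42};
      Payload* expected = nullptr;
      // release on success publishes `fresh`; acquire on failure synchronizes
      // with the thread that published `expected`
      if (cached_.compare_exchange_strong(expected, fresh, std::memory_order_release,
                                          std::memory_order_acquire)) {
        cached = fresh;
      } else {
        delete fresh;  // another thread won the race; reuse its value
        cached = expected;
      }
    }
    return cached;
  }
  ~Cache() { delete cached_.load(std::memory_order_acquire); }  // reclaim the published pointer once

 private:
  mutable std::atomic<Payload*> cached_{nullptr};
};

int main() {
  Cache c;
  std::printf("%d\n", c.GetOrCreate()->value);  // prints 42; later calls reuse the cached pointer
  return 0;
}

The renamed EmbeddedDLManagedTensorVersionedDeleter no longer deletes the tensor struct itself because, as in this sketch, the cached struct is now owned and reclaimed by the TensorObj destructor; the deleter only drops the reference it holds on the object.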

ffi/python/tvm_ffi/cython/base.pxi

Lines changed: 54 additions & 1 deletion
@@ -72,7 +72,7 @@ cdef extern from "dlpack/dlpack.h":
 
     ctypedef struct DLManagedTensorVersioned:
         DLPackVersion version
-        DLManagedTensor dl_tensor
+        DLTensor dl_tensor
         void* manager_ctx
         void (*deleter)(DLManagedTensorVersioned* self)
         uint64_t flags
@@ -195,6 +195,7 @@ cdef extern from "tvm/ffi/c_api.h":
        const TVMFFITypeMetadata* metadata
 
    int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil
+    int TVMFFIObjectIncRef(TVMFFIObjectHandle obj) nogil
    int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index,
                                 void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
    int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil
@@ -243,6 +244,58 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
                              TVMFFIStreamHandle* opt_out_original_stream) nogil
 
 
+cdef extern from "tvm_ffi_python_helpers.h":
+    # no need to expose fields of the call context
+    ctypedef struct TVMFFIPyCallContext:
+        int device_type
+        int device_id
+        TVMFFIStreamHandle stream
+
+    # setter data structure
+    ctypedef int (*DLPackPyObjectCExporter)(
+        void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
+    ) except -1
+
+    ctypedef struct TVMFFIPyArgSetter:
+        int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1
+        DLPackPyObjectCExporter dlpack_c_exporter
+
+    ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1
+    # The main call function
+    int TVMFFIPyFuncCall(
+        TVMFFIPyArgSetterFactory setter_factory,
+        void* chandle,
+        PyObject* py_arg_tuple,
+        TVMFFIAny* result,
+        int* c_api_ret_code
+    ) except -1
+
+    int TVMFFIPyCallFieldSetter(
+        TVMFFIPyArgSetterFactory setter_factory,
+        TVMFFIFieldSetter field_setter,
+        void* field_ptr,
+        PyObject* py_arg,
+        int* c_api_ret_code
+    ) except -1
+
+    int TVMFFIPyPyObjectToFFIAny(
+        TVMFFIPyArgSetterFactory setter_factory,
+        PyObject* py_arg,
+        TVMFFIAny* out,
+        int* c_api_ret_code
+    ) except -1
+
+    size_t TVMFFIPyGetDispatchMapSize() noexcept
+
+    void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept
+    void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept
+    # the predefined setters for common POD types
+    int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
+    int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
+    int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
+    int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1
+
+
 cdef class ByteArrayArg:
     cdef TVMFFIByteArray cdata
     cdef object py_data
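To see how the declarations above fit together, here is a hedged sketch of the call flow (every name ending in Stub is an invented stand-in; only the signatures declared above are real): a setter factory inspects a Python argument, hands back a setter functor, and the call path uses that functor to pack each argument into a TVMFFIAny-like slot before invoking the underlying FFI function.

#include <cstdio>

struct PyObjStub { int kind; long long ival; double fval; };  // stand-in for PyObject
struct AnyStub { long long v_int64; double v_float64; };      // stand-in for TVMFFIAny

struct ArgSetterStub {  // loosely mirrors TVMFFIPyArgSetter (dlpack_c_exporter omitted)
  int (*func)(ArgSetterStub* self, PyObjStub* arg, AnyStub* out);
};
using SetterFactoryStub = int (*)(PyObjStub* value, ArgSetterStub* out);

int SetterIntStub(ArgSetterStub*, PyObjStub* arg, AnyStub* out) {
  out->v_int64 = arg->ival;
  return 0;
}
int SetterFloatStub(ArgSetterStub*, PyObjStub* arg, AnyStub* out) {
  out->v_float64 = arg->fval;
  return 0;
}

// the factory inspects the value's type once and picks a setter for it
int FactoryStub(PyObjStub* value, ArgSetterStub* out) {
  out->func = (value->kind == 0) ? SetterIntStub : SetterFloatStub;
  return 0;
}

// loosely mirrors TVMFFIPyFuncCall: convert each argument through its setter
int CallStub(SetterFactoryStub factory, PyObjStub* args, int n, AnyStub* packed) {
  for (int i = 0; i < n; ++i) {
    ArgSetterStub setter;  // the real code would reuse a per-type cached setter
    if (factory(&args[i], &setter) != 0) return -1;
    if (setter.func(&setter, &args[i], &packed[i]) != 0) return -1;
  }
  return 0;  // a real call would now pass `packed` to the C FFI function
}

int main() {
  PyObjStub args[2] = {{0, 7, 0.0}, {1, 0, 2.5}};
  AnyStub packed[2] = {};
  if (CallStub(FactoryStub, args, 2, packed) == 0) {
    std::printf("%lld %g\n", packed[0].v_int64, packed[1].v_float64);
  }
  return 0;
}

Presumably the real TVMFFIPyFuncCall caches the factory's result per Python type in the C++ dispatch map, so the factory cost is paid only the first time a type is seen.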
