From e5f101fff11f86cde06697b1e6ef182f71071a6e Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 8 Sep 2025 21:28:14 -0400 Subject: [PATCH] [FFI][REFACTOR] Refactor python ffi call mechanism for perf This PR refactors python ffi call mechanism. Previously the argument setting can become an as things can be sensitive to the if checking order. This PR refactors the calling to leverage a C++ based dispatcher where each dispatch functor can be registered from Cython. --- ffi/CMakeLists.txt | 1 + ffi/include/tvm/ffi/container/tensor.h | 65 ++- ffi/python/tvm_ffi/cython/base.pxi | 55 +- ffi/python/tvm_ffi/cython/function.pxi | 538 +++++++++++------- ffi/python/tvm_ffi/cython/tensor.pxi | 100 +++- .../tvm_ffi/cython/tvm_ffi_python_helpers.h | 447 +++++++++++++++ ffi/scripts/benchmark_dlpack.py | 16 + ffi/src/ffi/extra/testing.cc | 2 +- 8 files changed, 986 insertions(+), 238 deletions(-) create mode 100644 ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index 94395d234352..f927403cbde9 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -215,6 +215,7 @@ if (TVM_FFI_BUILD_PYTHON_MODULE) Python_add_library(tvm_ffi_cython MODULE "${core_cpp}" WITH_SOABI) set_target_properties(tvm_ffi_cython PROPERTIES OUTPUT_NAME "core") endif() + target_include_directories(tvm_ffi_cython PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython) target_compile_features(tvm_ffi_cython PRIVATE cxx_std_17) target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_header) target_link_libraries(tvm_ffi_cython PRIVATE tvm_ffi_shared) diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h index 4d652e213fa6..5e20b7b51df2 100644 --- a/ffi/include/tvm/ffi/container/tensor.h +++ b/ffi/include/tvm/ffi/container/tensor.h @@ -30,6 +30,8 @@ #include #include +#include +#include #include namespace tvm { @@ -123,18 +125,26 @@ class TensorObj : public Object, public DLTensor { static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor; TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object); /// \endcond - + ~TensorObj() { + // deleting the cached dl managed tensor versioned + // need to acquire the value in case it is released by another thread + DLManagedTensorVersioned* cached = + cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire); + if (cached != nullptr) { + delete cached; + } + } /*! * \brief Move a Tensor to a DLPack managed tensor. * \return The converted DLPack managed tensor. */ DLManagedTensor* ToDLPack() const { + TensorObj* self = const_cast(this); DLManagedTensor* ret = new DLManagedTensor(); - TensorObj* from = const_cast(this); - ret->dl_tensor = *static_cast(from); - ret->manager_ctx = from; + ret->dl_tensor = *static_cast(self); + ret->manager_ctx = self; ret->deleter = DLManagedTensorDeleter; - details::ObjectUnsafe::IncRefObjectHandle(from); + details::ObjectUnsafe::IncRefObjectHandle(self); return ret; } @@ -143,16 +153,40 @@ class TensorObj : public Object, public DLTensor { * \return The converted DLPack managed tensor. */ DLManagedTensorVersioned* ToDLPackVersioned() const { - DLManagedTensorVersioned* ret = new DLManagedTensorVersioned(); TensorObj* from = const_cast(this); - ret->version.major = DLPACK_MAJOR_VERSION; - ret->version.minor = DLPACK_MINOR_VERSION; - ret->dl_tensor = *static_cast(from); - ret->manager_ctx = from; - ret->deleter = DLManagedTensorVersionedDeleter; - ret->flags = 0; + // if cache is set, directly return it + // we need to use acquire to ensure that write to DLManagedTensorVersioned + // from another thread is visible to this thread. + DLManagedTensorVersioned* cached = + cached_dl_managed_tensor_versioned_.load(std::memory_order_acquire); + // if cache is not set, create a new one + if (cached == nullptr) { + DLManagedTensorVersioned* ret = new DLManagedTensorVersioned(); + ret->version.major = DLPACK_MAJOR_VERSION; + ret->version.minor = DLPACK_MINOR_VERSION; + ret->dl_tensor = *static_cast(from); + ret->manager_ctx = from; + ret->deleter = EmbeddedDLManagedTensorVersionedDeleter; + ret->flags = 0; + DLManagedTensorVersioned* expected = nullptr; + // success set must release the new value to all other threads + // failure set must acquire, since the expected value is now coming + // from another thread that released this value + if (std::atomic_compare_exchange_strong_explicit(&cached_dl_managed_tensor_versioned_, + &expected, ret, std::memory_order_release, + std::memory_order_acquire)) { + // set is succes + cached = ret; + } else { + // delete the ret value as another thread raced to set this one first + delete ret; + cached = expected; + } + // at this point, cached is the value that officially set to the field + } + // inc the ref count of the from object details::ObjectUnsafe::IncRefObjectHandle(from); - return ret; + return cached; } protected: @@ -160,6 +194,8 @@ class TensorObj : public Object, public DLTensor { Optional shape_data_; /*! \brief Internal data to back returning strides. */ Optional strides_data_; + /*! \brief cached data to back returning DLManagedTensorVersioned. */ + mutable std::atomic cached_dl_managed_tensor_versioned_ = nullptr; /*! * \brief Deleter for DLManagedTensor. @@ -175,10 +211,9 @@ class TensorObj : public Object, public DLTensor { * \brief Deleter for DLManagedTensorVersioned. * \param tensor The DLManagedTensorVersioned to be deleted. */ - static void DLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) { + static void EmbeddedDLManagedTensorVersionedDeleter(DLManagedTensorVersioned* tensor) { TensorObj* obj = static_cast(tensor->manager_ctx); details::ObjectUnsafe::DecRefObjectHandle(obj); - delete tensor; } friend class Tensor; diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi index efb2225453f5..08b01d424f1f 100644 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ b/ffi/python/tvm_ffi/cython/base.pxi @@ -72,7 +72,7 @@ cdef extern from "dlpack/dlpack.h": ctypedef struct DLManagedTensorVersioned: DLPackVersion version - DLManagedTensor dl_tensor + DLTensor dl_tensor void* manager_ctx void (*deleter)(DLManagedTensorVersioned* self) uint64_t flags @@ -195,6 +195,7 @@ cdef extern from "tvm/ffi/c_api.h": const TVMFFITypeMetadata* metadata int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil + int TVMFFIObjectIncRef(TVMFFIObjectHandle obj) nogil int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, void (*deleter)(void*), TVMFFIObjectHandle* out) nogil int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil @@ -243,6 +244,58 @@ cdef extern from "tvm/ffi/extra/c_env_api.h": TVMFFIStreamHandle* opt_out_original_stream) nogil +cdef extern from "tvm_ffi_python_helpers.h": + # no need to expose fields of the call context + ctypedef struct TVMFFIPyCallContext: + int device_type + int device_id + TVMFFIStreamHandle stream + + # setter data structure + ctypedef int (*DLPackPyObjectCExporter)( + void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream + ) except -1 + + ctypedef struct TVMFFIPyArgSetter: + int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1 + DLPackPyObjectCExporter dlpack_c_exporter + + ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1 + # The main call function + int TVMFFIPyFuncCall( + TVMFFIPyArgSetterFactory setter_factory, + void* chandle, + PyObject* py_arg_tuple, + TVMFFIAny* result, + int* c_api_ret_code + ) except -1 + + int TVMFFIPyCallFieldSetter( + TVMFFIPyArgSetterFactory setter_factory, + TVMFFIFieldSetter field_setter, + void* field_ptr, + PyObject* py_arg, + int* c_api_ret_code + ) except -1 + + int TVMFFIPyPyObjectToFFIAny( + TVMFFIPyArgSetterFactory setter_factory, + PyObject* py_arg, + TVMFFIAny* out, + int* c_api_ret_code + ) except -1 + + size_t TVMFFIPyGetDispatchMapSize() noexcept + + void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept + void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept + # the predefined setters for common POD types + int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1 + int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1 + int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1 + int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, TVMFFIAny* out) except -1 + + cdef class ByteArrayArg: cdef TVMFFIByteArray cdata cdef object py_data diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi index 71591d95267d..b77b19a2eabb 100644 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ b/ffi/python/tvm_ffi/cython/function.pxi @@ -29,6 +29,9 @@ else: torch = None +_torch_dlpack_c_exporter_ptr = None + + cdef inline object make_ret_small_str(TVMFFIAny result): """convert small string to return value.""" cdef TVMFFIByteArray bytes @@ -45,7 +48,6 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result): cdef inline object make_ret(TVMFFIAny result): """convert result to return value.""" - # TODO: Implement cdef int32_t type_index type_index = result.type_index if type_index == kTVMFFITensor: @@ -55,7 +57,8 @@ cdef inline object make_ret(TVMFFIAny result): return make_ret_opaque_object(result) elif type_index >= kTVMFFIStaticObjectBegin: return make_ret_object(result) - elif type_index == kTVMFFINone: + # the following code should be optimized to switch case + if type_index == kTVMFFINone: return None elif type_index == kTVMFFIBool: return bool(result.v_int64) @@ -84,197 +87,325 @@ cdef inline object make_ret(TVMFFIAny result): raise ValueError("Unhandled type index %d" % type_index) -cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args, - int* ctx_dev_type, int* ctx_dev_id, TVMFFIStreamHandle* ctx_stream) except -1: - """Pack arguments into c args tvm call accept""" - cdef unsigned long long temp_ptr - cdef DLTensor* temp_dltensor - cdef int is_cuda = 0 - - for i, arg in enumerate(py_args): - # clear the value to ensure zero padding on 32bit platforms - if sizeof(void*) != 8: - out[i].v_int64 = 0 - out[i].zero_padding = 0 - - if isinstance(arg, Tensor): - if (arg).chandle != NULL: - out[i].type_index = kTVMFFITensor - out[i].v_ptr = (arg).chandle - else: - out[i].type_index = kTVMFFIDLTensorPtr - out[i].v_ptr = (arg).cdltensor - elif isinstance(arg, Object): - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - elif torch is not None and isinstance(arg, torch.Tensor): - is_cuda = arg.is_cuda - arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg)) - out[i].type_index = kTVMFFITensor - out[i].v_ptr = (arg).chandle - temp_dltensor = TVMFFITensorGetDLTensorPtr((arg).chandle) - # record the stream and device for torch context - if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1: - ctx_dev_type[0] = temp_dltensor.device.device_type - ctx_dev_id[0] = temp_dltensor.device.device_id - # This is an API that dynamo and other uses to get the raw stream from torch - temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id) - ctx_stream[0] = temp_ptr - temp_args.append(arg) - elif hasattr(arg, "__dlpack__"): - ffi_arg = from_dlpack(arg) - out[i].type_index = kTVMFFITensor - out[i].v_ptr = (ffi_arg).chandle - # record the stream from the source framework context when possible - temp_dltensor = TVMFFITensorGetDLTensorPtr((ffi_arg).chandle) - if (temp_dltensor.device.device_type != kDLCPU and - ctx_dev_type != NULL and - ctx_dev_type[0] == -1): - # __tvm_ffi_env_stream__ returns the expected stream that should be set - # through TVMFFIEnvSetCurrentStream when calling a TVM FFI function - if hasattr(arg, "__tvm_ffi_env_stream__"): - # Ideally projects should directly setup their stream context API - # write through by also calling TVMFFIEnvSetCurrentStream - # so we do not need this protocol to do exchange - ctx_dev_type[0] = temp_dltensor.device.device_type - ctx_dev_id[0] = temp_dltensor.device.device_id - temp_ptr= arg.__tvm_ffi_env_stream__() - ctx_stream[0] = temp_ptr - temp_args.append(ffi_arg) - elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None: - arg = arg.__tvm_ffi_object__ - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - elif isinstance(arg, bool): - # A python `bool` is a subclass of `int`, so this check - # must occur before `Integral`. - out[i].type_index = kTVMFFIBool - out[i].v_int64 = arg - elif isinstance(arg, Integral): - out[i].type_index = kTVMFFIInt - out[i].v_int64 = arg - elif isinstance(arg, float): - out[i].type_index = kTVMFFIFloat - out[i].v_float64 = arg - elif isinstance(arg, _CLASS_DTYPE): - # dtype is a subclass of str, so this check occur before str - arg = arg.__tvm_ffi_dtype__ - out[i].type_index = kTVMFFIDataType - out[i].v_dtype = (arg).cdtype - elif isinstance(arg, _CLASS_DEVICE): - out[i].type_index = kTVMFFIDevice - out[i].v_device = (arg).cdevice - elif isinstance(arg, str): - tstr = c_str(arg) - out[i].type_index = kTVMFFIRawStr - out[i].v_c_str = tstr - temp_args.append(tstr) - elif arg is None: - out[i].type_index = kTVMFFINone - out[i].v_int64 = 0 - elif isinstance(arg, Real): - out[i].type_index = kTVMFFIFloat - out[i].v_float64 = arg - elif isinstance(arg, (bytes, bytearray)): - arg = ByteArrayArg(arg) - out[i].type_index = kTVMFFIByteArrayPtr - out[i].v_int64 = 0 - out[i].v_ptr = (arg).cptr() - temp_args.append(arg) - elif isinstance(arg, (list, tuple, dict, ObjectConvertible)): - arg = _FUNC_CONVERT_TO_OBJECT(arg) - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - temp_args.append(arg) - elif isinstance(arg, ctypes.c_void_p): - out[i].type_index = kTVMFFIOpaquePtr - out[i].v_ptr = c_handle(arg) - elif isinstance(arg, Exception): - arg = _convert_to_ffi_error(arg) - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - temp_args.append(arg) - elif isinstance(arg, ObjectRValueRef): - out[i].type_index = kTVMFFIObjectRValueRef - out[i].v_ptr = &(((arg.obj)).chandle) - elif callable(arg): - arg = _convert_to_ffi_func(arg) - out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out[i].v_ptr = (arg).chandle - temp_args.append(arg) - else: - arg = _convert_to_opaque_object(arg) - out[i].type_index = kTVMFFIOpaquePyObject - out[i].v_ptr = (arg).chandle - temp_args.append(arg) - - -cdef inline int FuncCall3(void* chandle, - tuple args, - TVMFFIAny* result, - int* c_api_ret_code) except -1: - # fast path with stack alloca for less than 3 args - cdef TVMFFIAny[3] packed_args - cdef int nargs = len(args) - cdef int ctx_dev_type = -1 - cdef int ctx_dev_id = 0 - cdef TVMFFIStreamHandle ctx_stream = NULL - cdef TVMFFIStreamHandle prev_stream = NULL - temp_args = [] - make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream) - with nogil: - if ctx_dev_type != -1: - # set the stream based on ctx stream - c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream) - if c_api_ret_code[0] != 0: - return 0 - c_api_ret_code[0] = TVMFFIFunctionCall( - chandle, &packed_args[0], nargs, result - ) - # restore the original stream if it is not the same as the context stream - if ctx_dev_type != -1 and prev_stream != ctx_stream: - # restore the original stream - c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL) - if c_api_ret_code[0] != 0: - return 0 +##---------------------------------------------------------------------------- +## Implementation of setters using same naming style as TVMFFIPyArgSetterXXX_ +##---------------------------------------------------------------------------- +cdef int TVMFFIPyArgSetterTensor_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* arg, TVMFFIAny* out +) except -1: + if (arg).chandle != NULL: + out.type_index = kTVMFFITensor + out.v_ptr = (arg).chandle + else: + out.type_index = kTVMFFIDLTensorPtr + out.v_ptr = (arg).cdltensor + return 0 + + +cdef int TVMFFIPyArgSetterObject_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* arg, TVMFFIAny* out +) except -1: + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle + return 0 + + +cdef int TVMFFIPyArgSetterDLPackCExporter_( + TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx, + PyObject* arg, TVMFFIAny* out +) except -1: + cdef DLManagedTensorVersioned* temp_managed_tensor + cdef TVMFFIObjectHandle temp_chandle + cdef TVMFFIStreamHandle env_stream = NULL + + if ctx.device_id != -1: + # already queried device, do not do it again, pass NULL to stream + if (this.dlpack_c_exporter)(arg, &temp_managed_tensor, NULL) != 0: + return -1 + else: + # query string on the envrionment stream + if (this.dlpack_c_exporter)(arg, &temp_managed_tensor, &env_stream) != 0: + return -1 + # If device is not CPU, we should set the device type and id + if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU: + ctx.stream = env_stream + ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type + ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id + # run conversion + if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, &temp_chandle) != 0: + raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor") + out.type_index = kTVMFFITensor + out.v_ptr = temp_chandle + TVMFFIPyPushTempFFIObject(ctx, temp_chandle) + return 0 + + +cdef int TVMFFIPyArgSetterTorch_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Current setter for torch.Tensor, go through python and not as fast as c exporter""" + cdef object arg = py_arg + is_cuda = arg.is_cuda + arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg)) + out.type_index = kTVMFFITensor + out.v_ptr = (arg).chandle + temp_dltensor = TVMFFITensorGetDLTensorPtr((arg).chandle) + # record the stream and device for torch context + if is_cuda and ctx.device_type != -1: + ctx.device_type = temp_dltensor.device.device_type + ctx.device_id = temp_dltensor.device.device_id + # This is an API that dynamo and other uses to get the raw stream from torch + temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id) + ctx.stream = temp_ptr + # push to temp and clear the handle + TVMFFIPyPushTempPyObject(ctx, arg) + return 0 + + +cdef int TVMFFIPyArgSetterDLPack_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for __dlpack__ mechanism through python, not as fast as c exporter""" + cdef TVMFFIObjectHandle temp_chandle + cdef object arg = py_arg + _from_dlpack_universal(arg, 0, 0, &temp_chandle) + out.type_index = kTVMFFITensor + out.v_ptr = temp_chandle + # record the stream from the source framework context when possible + temp_dltensor = TVMFFITensorGetDLTensorPtr(temp_chandle) + if (temp_dltensor.device.device_type != kDLCPU and + ctx.device_type != -1): + # __tvm_ffi_env_stream__ returns the expected stream that should be set + # through TVMFFIEnvSetCurrentStream when calling a TVM FFI function + if hasattr(arg, "__tvm_ffi_env_stream__"): + # Ideally projects should directly setup their stream context API + # write through by also calling TVMFFIEnvSetCurrentStream + # so we do not need this protocol to do exchange + ctx.device_type = temp_dltensor.device.device_type + ctx.device_id = temp_dltensor.device.device_id + temp_ptr= arg.__tvm_ffi_env_stream__() + ctx.stream = temp_ptr + TVMFFIPyPushTempFFIObject(ctx, temp_chandle) + return 0 + + +cdef int TVMFFIPyArgSetterDType_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for dtype""" + cdef object arg = py_arg + # dtype is a subclass of str, so this check occur before str + arg = arg.__tvm_ffi_dtype__ + out.type_index = kTVMFFIDataType + out.v_dtype = (arg).cdtype + return 0 + + +cdef int TVMFFIPyArgSetterDevice_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for device""" + cdef object arg = py_arg + out.type_index = kTVMFFIDevice + out.v_device = (arg).cdevice + return 0 + + +cdef int TVMFFIPyArgSetterStr_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for str""" + cdef object arg = py_arg + + if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None: + arg = arg.__tvm_ffi_object__ + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle + return 0 + + tstr = c_str(arg) + out.type_index = kTVMFFIRawStr + out.v_c_str = tstr + TVMFFIPyPushTempPyObject(ctx, tstr) return 0 -cdef inline int FuncCall(void* chandle, - tuple args, - TVMFFIAny* result, - int* c_api_ret_code) except -1: - cdef int nargs = len(args) - cdef int ctx_dev_type = -1 - cdef int ctx_dev_id = 0 - cdef TVMFFIStreamHandle ctx_stream = NULL - cdef TVMFFIStreamHandle prev_stream = NULL +cdef int TVMFFIPyArgSetterBytes_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for bytes""" + cdef object arg = py_arg - if nargs <= 3: - FuncCall3(chandle, args, result, c_api_ret_code) + if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None: + arg = arg.__tvm_ffi_object__ + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle return 0 - cdef vector[TVMFFIAny] packed_args - packed_args.resize(nargs) + arg = ByteArrayArg(arg) + out.type_index = kTVMFFIByteArrayPtr + out.v_int64 = 0 + out.v_ptr = (arg).cptr() + TVMFFIPyPushTempPyObject(ctx, arg) + return 0 + + +cdef int TVMFFIPyArgSetterCtypesVoidPtr_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for ctypes.c_void_p""" + out.type_index = kTVMFFIOpaquePtr + out.v_ptr = c_handle(py_arg) + return 0 + + +cdef int TVMFFIPyArgSetterObjectRValueRef_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for ObjectRValueRef""" + cdef object arg = py_arg + out.type_index = kTVMFFIObjectRValueRef + out.v_ptr = &(((arg.obj)).chandle) + return 0 + - temp_args = [] - make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, &ctx_stream) +cdef int TVMFFIPyArgSetterCallable_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for Callable""" + cdef object arg = py_arg + arg = _convert_to_ffi_func(arg) + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle + TVMFFIPyPushTempPyObject(ctx, arg) + return 0 - with nogil: - if ctx_dev_type != -1: - c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream) - if c_api_ret_code[0] != 0: - return 0 - c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result) - # restore the original stream if it is not the same as the context stream - if ctx_dev_type != -1 and prev_stream != ctx_stream: - c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL) - if c_api_ret_code[0] != 0: - return 0 +cdef int TVMFFIPyArgSetterException_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for Exception""" + cdef object arg = py_arg + arg = _convert_to_ffi_error(arg) + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle + TVMFFIPyPushTempPyObject(ctx, arg) return 0 +cdef int TVMFFIPyArgSetterFallback_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Fallback setter for all other types""" + cdef object arg = py_arg + # fallback must contain PyNativeObject check + if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None: + arg = arg.__tvm_ffi_object__ + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle + elif isinstance(arg, (list, tuple, dict, ObjectConvertible)): + arg = _FUNC_CONVERT_TO_OBJECT(arg) + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle + TVMFFIPyPushTempPyObject(ctx, arg) + else: + arg = _convert_to_opaque_object(arg) + out.type_index = kTVMFFIOpaquePyObject + out.v_ptr = (arg).chandle + TVMFFIPyPushTempPyObject(ctx, arg) + + +cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) except -1: + """ + Factory function that creates an argument setter for a given Python argument type. + """ + # NOTE: the order of checks matter here + # becase each argument may satisfy multiple checks + # priortize native types over external types + cdef object arg = value + cdef long long temp_ptr + if arg is None: + out.func = TVMFFIPyArgSetterNone_ + return 0 + if isinstance(arg, Tensor): + out.func = TVMFFIPyArgSetterTensor_ + return 0 + if isinstance(arg, Object): + out.func = TVMFFIPyArgSetterObject_ + return 0 + if isinstance(arg, ObjectRValueRef): + out.func = TVMFFIPyArgSetterObjectRValueRef_ + return 0 + # external tensors + if hasattr(arg, "__dlpack_c_exporter__"): + out.func = TVMFFIPyArgSetterDLPackCExporter_ + temp_ptr = arg.__dlpack_c_exporter__ + out.dlpack_c_exporter = temp_ptr + return 0 + if torch is not None and isinstance(arg, torch.Tensor): + if _torch_dlpack_c_exporter_ptr is not None: + temp_ptr = _torch_dlpack_c_exporter_ptr + out.func = TVMFFIPyArgSetterDLPackCExporter_ + out.dlpack_c_exporter = temp_ptr + else: + out.func = TVMFFIPyArgSetterTorch_ + return 0 + if hasattr(arg, "__dlpack__"): + out.func = TVMFFIPyArgSetterDLPack_ + return 0 + if isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + out.func = TVMFFIPyArgSetterBool_ + return 0 + if isinstance(arg, Integral): + out.func = TVMFFIPyArgSetterInt_ + return 0 + if isinstance(arg, Real): + out.func = TVMFFIPyArgSetterFloat_ + return 0 + # dtype is a subclass of str, so this check must occur before str + if isinstance(arg, _CLASS_DTYPE): + out.func = TVMFFIPyArgSetterDType_ + return 0 + if isinstance(arg, _CLASS_DEVICE): + out.func = TVMFFIPyArgSetterDevice_ + return 0 + if isinstance(arg, str): + out.func = TVMFFIPyArgSetterStr_ + return 0 + if isinstance(arg, (bytes, bytearray)): + out.func = TVMFFIPyArgSetterBytes_ + return 0 + if isinstance(arg, ctypes.c_void_p): + out.func = TVMFFIPyArgSetterCtypesVoidPtr_ + return 0 + if callable(arg): + out.func = TVMFFIPyArgSetterCallable_ + return 0 + if isinstance(arg, Exception): + out.func = TVMFFIPyArgSetterException_ + return 0 + # default to opaque object + out.func = TVMFFIPyArgSetterFallback_ + return 0 + +#--------------------------------------------------------------------------------------------- +## Implementation of function calling +#--------------------------------------------------------------------------------------------- cdef inline int ConstructorCall(void* constructor_handle, tuple args, void** handle) except -1: @@ -284,7 +415,7 @@ cdef inline int ConstructorCall(void* constructor_handle, # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone result.type_index = kTVMFFINone result.v_int64 = 0 - FuncCall(constructor_handle, args, &result, &c_api_ret_code) + TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory_, constructor_handle, args, &result, &c_api_ret_code) CHECK_CALL(c_api_ret_code) handle[0] = result.v_ptr return 0 @@ -304,7 +435,12 @@ class Function(Object): # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone result.type_index = kTVMFFINone result.v_int64 = 0 - FuncCall((self).chandle, args, &result, &c_api_ret_code) + TVMFFIPyFuncCall( + TVMFFIPyArgSetterFactory_, + (self).chandle, args, + &result, + &c_api_ret_code + ) # NOTE: logic is same as check_call # directly inline here to simplify traceback if c_api_ret_code == 0: @@ -336,13 +472,15 @@ cdef class FieldSetter: cdef int64_t offset def __call__(self, Object obj, value): - cdef TVMFFIAny[1] packed_args cdef int c_api_ret_code cdef void* field_ptr = ((obj).chandle) + self.offset - cdef int nargs = 1 - temp_args = [] - make_args((value,), &packed_args[0], temp_args, NULL, NULL, NULL) - c_api_ret_code = self.setter(field_ptr, &packed_args[0]) + TVMFFIPyCallFieldSetter( + TVMFFIPyArgSetterFactory_, + self.setter, + field_ptr, + value, + &c_api_ret_code + ) # NOTE: logic is same as check_call # directly inline here to simplify traceback if c_api_ret_code == 0: @@ -466,6 +604,7 @@ cdef int tvm_ffi_callback(void* context, TVMFFIAny* result) noexcept with gil: cdef list pyargs cdef TVMFFIAny temp_result + cdef int c_api_ret_code local_pyfunc = (context) pyargs = [] for i in range(num_args): @@ -474,16 +613,21 @@ cdef int tvm_ffi_callback(void* context, try: rv = local_pyfunc(*pyargs) + TVMFFIPyPyObjectToFFIAny( + TVMFFIPyArgSetterFactory_, + rv, + result, + &c_api_ret_code + ) + if c_api_ret_code == 0: + return 0 + elif c_api_ret_code == -2: + raise_existing_error() + return -1 except Exception as err: set_last_ffi_error(err) return -1 - temp_args = [] - make_args((rv,), &temp_result, temp_args, NULL, NULL, NULL) - CHECK_CALL(TVMFFIAnyViewToOwnedAny(&temp_result, result)) - - return 0 - def _convert_to_ffi_func(object pyfunc): """Convert a python function to TVM FFI function""" @@ -513,6 +657,12 @@ def _convert_to_opaque_object(object pyobject): return ret +def _print_debug_info(): + """Get the size of the dispatch map""" + cdef size_t size = TVMFFIPyGetDispatchMapSize() + print(f"TVMFFIPyGetDispatchMapSize: {size}") + + _STR_CONSTRUCTOR = _get_global_func("ffi.String", False) _BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False) _OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True) diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi b/ffi/python/tvm_ffi/cython/tensor.pxi index 2072ad056797..fca6cc0bbc08 100644 --- a/ffi/python/tvm_ffi/cython/tensor.pxi +++ b/ffi/python/tvm_ffi/cython/tensor.pxi @@ -43,6 +43,21 @@ cdef void _c_dlpack_versioned_deleter(object pycaps): dltensor.deleter(dltensor) +cdef inline object _from_dlpack_intptr( + void* dlpack +): + cdef TVMFFIObjectHandle chandle + cdef DLManagedTensor* ptr = dlpack + cdef int c_api_ret_code + cdef int c_req_alignment = 0 + cdef int c_req_contiguous = 0 + with nogil: + c_api_ret_code = TVMFFITensorFromDLPack( + ptr, c_req_alignment, c_req_contiguous, &chandle) + CHECK_CALL(c_api_ret_code) + return make_tensor_from_chandle(chandle) + + cdef inline int _from_dlpack( object dltensor, int require_alignment, int require_contiguous, TVMFFIObjectHandle* out @@ -86,27 +101,10 @@ cdef inline int _from_dlpack_versioned( raise ValueError("Expect a dltensor_versioned field, PyCapsule can only be consumed once") -def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False): - """ - Convert an external tensor to an Tensor. - - Parameters - ---------- - ext_tensor : object - The external tensor to convert. - - require_alignment : int - The minimum required alignment to check for the tensor. - - require_contiguous : bool - Whether to check for contiguous memory. - - Returns - ------- - tensor : :py:class:`tvm_ffi.Tensor` - The converted tensor. - """ - cdef TVMFFIObjectHandle chandle +cdef inline int _from_dlpack_universal( + object ext_tensor, int require_alignment, + int require_contiguous, TVMFFIObjectHandle* out +) except -1: # as of most frameworks do not yet support v1.1 # move to false as most frameworks get upgraded. cdef int favor_legacy_dlpack = True @@ -114,10 +112,10 @@ def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False): if hasattr(ext_tensor, '__dlpack__'): if favor_legacy_dlpack: _from_dlpack( - ext_tensor.__dlpack__(), + ext_tensor.__dlpack__(), require_alignment, require_contiguous, - &chandle + out ) else: try: @@ -125,14 +123,14 @@ def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False): ext_tensor.__dlpack__(max_version=__dlpack_version__), require_alignment, require_contiguous, - &chandle + out ) except TypeError: _from_dlpack( ext_tensor.__dlpack__(), require_alignment, require_contiguous, - &chandle + out ) else: if pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor_versioned): @@ -140,17 +138,41 @@ def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False): ext_tensor, require_alignment, require_contiguous, - &chandle + out ) elif pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor): _from_dlpack( ext_tensor, require_alignment, require_contiguous, - &chandle + out ) else: raise TypeError("Expect from_dlpack to take either a compatible tensor or PyCapsule") + + +def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False): + """ + Convert an external tensor to an Tensor. + + Parameters + ---------- + ext_tensor : object + The external tensor to convert. + + require_alignment : int + The minimum required alignment to check for the tensor. + + require_contiguous : bool + Whether to check for contiguous memory. + + Returns + ------- + tensor : :py:class:`tvm_ffi.Tensor` + The converted tensor. + """ + cdef TVMFFIObjectHandle chandle + _from_dlpack_universal(ext_tensor, require_alignment, require_contiguous, &chandle) return make_tensor_from_chandle(chandle) @@ -260,9 +282,33 @@ _set_class_tensor(Tensor) _register_object_by_index(kTVMFFITensor, Tensor) + +cdef int _dltensor_test_wrapper_dlpack_c_exporter( + void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream +) except -1: + cdef object ref_obj = (obj) + cdef DLTensorTestWrapper wrapper = ref_obj + cdef TVMFFIStreamHandle current_stream + + if env_stream != NULL: + env_stream[0] = TVMFFIEnvGetCurrentStream( + wrapper.tensor.cdltensor.device.device_type, + wrapper.tensor.cdltensor.device.device_id + ) + return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out) + + +def _dltensor_test_wrapper_dlpack_c_exporter_as_intptr(): + cdef DLPackPyObjectCExporter converter_func = _dltensor_test_wrapper_dlpack_c_exporter + cdef void* temp_ptr = converter_func + cdef long long temp_int_ptr = temp_ptr + return temp_int_ptr + + cdef class DLTensorTestWrapper: """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose. """ + __dlpack_c_exporter__ = _dltensor_test_wrapper_dlpack_c_exporter_as_intptr() cdef Tensor tensor def __init__(self, tensor): self.tensor = tensor diff --git a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h new file mode 100644 index 000000000000..32ded385bae8 --- /dev/null +++ b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file tvm_ffi_python_helpers.h + * \brief C++ based helpers for the Python FFI call to optimize performance. + */ +#ifndef TVM_FFI_PYTHON_HELPERS_H_ +#define TVM_FFI_PYTHON_HELPERS_H_ + +#include +#include +#include + +#include +#include + +///-------------------------------------------------------------------------------- +/// We deliberately designed the data structure and function to be C-style +// prefixed with TVMFFIPy so they can be easily invoked through Cython. +///-------------------------------------------------------------------------------- +/*! + * \brief Context for each ffi call to track the stream, device and temporary arguments. + */ +struct TVMFFIPyCallContext { + /*! \brief The workspace for the packed arguments */ + TVMFFIAny* packed_args = nullptr; + /*! \brief Detected device type, if any */ + int device_type = -1; + /*! \brief Detected device id, if any */ + int device_id = 0; + /*! \brief Detected stream, if any */ + void* stream = nullptr; + /*! \brief the temporary arguments to be recycled */ + void** temp_ffi_objects = nullptr; + /*! \brief the number of temporary arguments */ + int num_temp_ffi_objects = 0; + /*! \brief the temporary arguments to be recycled */ + void** temp_py_objects = nullptr; + /*! \brief the number of temporary arguments */ + int num_temp_py_objects = 0; +}; + +/*! + * \brief C-style function pointer to speed convert a Tensor to a DLManagedTensorVersioned. + * \param py_obj The Python object to convert, this should be PyObject* + * \param out The output DLManagedTensorVersioned. + * \param env_stream Outputs the current context stream of the device provided by the tensor. + * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. + * \note We use void* to avoid dependency on Python.h so this specific type is + * not dependent on Python.h and can be copied to dlpack.h + */ +typedef int (*DLPackPyObjectCExporter)(void* py_obj, DLManagedTensorVersioned** out, + void** env_stream); + +/*! \brief Argument setter for a given python argument. */ +struct TVMFFIPyArgSetter { + /*! + * \brief Function pointer to invoke the setter. + * \param self Pointer to this, this should be TVMFFIPyArgSetter* + * \param call_ctx The call context. + * \param arg The python argument to be set + * \param out The output argument. + * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. + */ + int (*func)(TVMFFIPyArgSetter* self, TVMFFIPyCallContext* call_ctx, PyObject* arg, + TVMFFIAny* out); + /*! + * \brief Optional DLPack exporter for for setters that leverages DLPack protocol. + */ + DLPackPyObjectCExporter dlpack_c_exporter{nullptr}; + /*! + * \brief Invoke the setter. + * \param call_ctx The call context. + * \param arg The python argument to be set + * \param out The output argument. + * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. + */ + int operator()(TVMFFIPyCallContext* call_ctx, PyObject* arg, TVMFFIAny* out) const { + return (*func)(const_cast(this), call_ctx, arg, out); + } +}; + +//--------------------------------------------------------------------------------------------- +// The following section contains predefined setters for common POD types +// They ar not meant to be used directly, but instead being registered to TVMFFIPyCallManager +//--------------------------------------------------------------------------------------------- +int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, + TVMFFIAny* out) noexcept { + out->type_index = kTVMFFIFloat; + // this function getsdispatched when type is already float, so no need to worry about error + out->v_float64 = PyFloat_AsDouble(arg); + return 0; +} + +int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, + TVMFFIAny* out) noexcept { + int overflow = 0; + out->type_index = kTVMFFIInt; + out->v_int64 = PyLong_AsLongLongAndOverflow(arg, &overflow); + + if (overflow != 0) { + PyErr_SetString(PyExc_OverflowError, "Python int too large to convert to int64_t"); + return -1; + } + return 0; +} + +int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, + TVMFFIAny* out) noexcept { + out->type_index = kTVMFFIBool; + // this function getsdispatched when type is already bool, so no need to worry about error + out->v_int64 = PyLong_AsLong(arg); + return 0; +} + +int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, PyObject* arg, + TVMFFIAny* out) noexcept { + out->type_index = kTVMFFINone; + out->v_int64 = 0; + return 0; +} + +//--------------------------------------------------------------------------------------------- +// The following section contains the dispatcher logic for function calling +//--------------------------------------------------------------------------------------------- +/*! + * \brief Factory function that creates an argument setter for a given Python argument type. + * + * This factory function analyzes a Python argument and creates an appropriate setter + * that can convert Python objects of the same type to C arguments for TVM FFI calls. + * The setter will be cached for future use for setting argument of the same type. + * + * \param arg The Python argument value used as a type example. + * \param out Output parameter that receives the created argument setter. + * \return 0 on success, -1 on failure. PyError should be set if -1 is returned. + * + * \note This is a callback function supplied by the caller. The factory must satisfy + * the invariance that the same setter can be used for other arguments with + * the same type as the provided example argument. + */ +typedef int (*TVMFFIPyArgSetterFactory)(PyObject* arg, TVMFFIPyArgSetter* out); + +/*! + * \brief A manager class that handles python ffi calls. + */ +class TVMFFIPyCallManager { + public: + /*! + * \brief Get the thread local call manager. + * \return The thread local call manager. + */ + static TVMFFIPyCallManager* ThreadLocal() { + static thread_local TVMFFIPyCallManager inst; + return &inst; + } + /*! + * \brief auxiliary class that manages the call stack in RAII manner. + * + * In most cases, it will try to allocate from temp_stack, + * then allocate from heap if the request goes beyond the stack size. + */ + class CallStack : public TVMFFIPyCallContext { + public: + CallStack(TVMFFIPyCallManager* manager, int64_t num_args) : manager_ptr_(manager) { + static_assert(sizeof(TVMFFIAny) >= (sizeof(void*) * 2)); + static_assert(alignof(TVMFFIAny) % alignof(void*) == 0); + old_stack_top_ = manager->stack_top_; + int64_t requested_count = num_args * 2; + TVMFFIAny* stack_head = manager->temp_stack_.data() + manager->stack_top_; + if (manager->stack_top_ + requested_count > + static_cast(manager->temp_stack_.size())) { + // allocate from heap + heap_ptr_ = new TVMFFIAny[requested_count]; + stack_head = heap_ptr_; + } else { + manager->stack_top_ += requested_count; + } + this->packed_args = stack_head; + this->temp_ffi_objects = reinterpret_cast(stack_head + num_args); + this->temp_py_objects = this->temp_ffi_objects + num_args; + } + + ~CallStack() { + try { + // recycle the temporary arguments if any + for (int i = 0; i < this->num_temp_ffi_objects; ++i) { + TVMFFIObject* obj = static_cast(this->temp_ffi_objects[i]); + if (obj->deleter != nullptr) { + obj->deleter(obj, kTVMFFIObjectDeleterFlagBitMaskBoth); + } + } + for (int i = 0; i < this->num_temp_py_objects; ++i) { + Py_DecRef(static_cast(this->temp_py_objects[i])); + } + } catch (const std::exception& ex) { + // very rare, catch c++ exception and set python error + PyErr_SetString(PyExc_RuntimeError, ex.what()); + } + // now recycle the memory of the call stack + if (heap_ptr_ == nullptr) { + manager_ptr_->stack_top_ = old_stack_top_; + } else { + delete[] heap_ptr_; + } + } + + private: + /*! + *\brief The manager of the call stack + * If stored on stack, must set it to point to parent. + */ + TVMFFIPyCallManager* manager_ptr_ = nullptr; + /*! \brief The heap of the call stack */ + TVMFFIAny* heap_ptr_ = nullptr; + /*! \brief The old stack size */ + int64_t old_stack_top_ = 0; + }; + + /*! + * \brief Call a function with a variable number of arguments + * \param setter_factory The factory function to create the setter + * \param func_handle The handle of the function to call + * \param py_arg_tuple The arguments to the function + * \param result The result of the function + * \param c_api_ret_code The return code of the C-call + * \return 0 on when there is no python error, -1 on python error + * \note When an error happens on FFI side, we should return 0 and set c_api_ret_code + */ + int Call(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, PyObject* py_arg_tuple, + TVMFFIAny* result, int* c_api_ret_code) { + int64_t num_args = PyTuple_Size(py_arg_tuple); + if (num_args == -1) return -1; + try { + // allocate a call stack + CallStack ctx(this, num_args); + // Iterate over the arguments and set them + for (int64_t i = 0; i < num_args; ++i) { + PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i); + TVMFFIAny* c_arg = ctx.packed_args + i; + if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; + } + TVMFFIStreamHandle prev_stream = nullptr; + // setup stream context if needed + if (ctx.device_type != -1) { + c_api_ret_code[0] = + TVMFFIEnvSetCurrentStream(ctx.device_type, ctx.device_id, ctx.stream, &prev_stream); + // setting failed, directly return + if (c_api_ret_code[0] != 0) return 0; + } + // call the function + // release the GIL + Py_BEGIN_ALLOW_THREADS; + c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); + Py_END_ALLOW_THREADS; + // restore the original stream + if (ctx.device_type != -1 && prev_stream != ctx.stream) { + // always try recover first, even if error happens + if (TVMFFIEnvSetCurrentStream(ctx.device_type, ctx.device_id, prev_stream, nullptr) != 0) { + // recover failed, set python error + PyErr_SetString(PyExc_RuntimeError, "Failed to recover stream"); + return -1; + } + } + return 0; + } catch (const std::exception& ex) { + // very rare, catch c++ exception and set python error + PyErr_SetString(PyExc_RuntimeError, ex.what()); + return -1; + } + } + + int SetField(TVMFFIPyArgSetterFactory setter_factory, TVMFFIFieldSetter field_setter, + void* field_ptr, PyObject* py_arg, int* c_api_ret_code) { + try { + CallStack ctx(this, 1); + TVMFFIAny* c_arg = ctx.packed_args; + if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; + c_api_ret_code[0] = (*field_setter)(field_ptr, c_arg); + return 0; + } catch (const std::exception& ex) { + // very rare, catch c++ exception and set python error + PyErr_SetString(PyExc_RuntimeError, ex.what()); + return -1; + } + } + + int PyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, PyObject* py_arg, TVMFFIAny* out, + int* c_api_ret_code) { + try { + CallStack ctx(this, 1); + TVMFFIAny* c_arg = ctx.packed_args; + if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; + c_api_ret_code[0] = TVMFFIAnyViewToOwnedAny(c_arg, out); + return 0; + } catch (const std::exception& ex) { + // very rare, catch c++ exception and set python error + PyErr_SetString(PyExc_RuntimeError, ex.what()); + return -1; + } + } + /*! + * \brief Get the size of the dispatch map + * \return The size of the dispatch map + */ + size_t GetDispatchMapSize() const { return dispatch_map_.size(); } + + private: + TVMFFIPyCallManager() { + static constexpr size_t kDefaultDispatchCapacity = 32; + static constexpr size_t kDefaultStackSize = 32; + dispatch_map_.reserve(kDefaultDispatchCapacity); + temp_stack_.resize(kDefaultStackSize * 2); + } + /*! + * \brief Set an py_arg to out. + * \param setter_factory The factory function to create the setter + * \param ctx The call context + * \param py_arg The python argument to be set + * \param out The output argument + * \return 0 on success, -1 on failure + */ + int SetArgument(TVMFFIPyArgSetterFactory setter_factory, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out) { + PyTypeObject* py_type = Py_TYPE(py_arg); + // pre-zero the output argument, modulo the type index + out->type_index = kTVMFFINone; + out->zero_padding = 0; + out->v_int64 = 0; + // find the pre-cached setter + // This class is thread-local, so we don't need to worry about race condition + auto it = dispatch_map_.find(py_type); + if (it != dispatch_map_.end()) { + TVMFFIPyArgSetter setter = it->second; + // if error happens, propagate it back + if (setter(ctx, py_arg, out) != 0) return -1; + } else { + // no dispatch found, query and create a new one. + TVMFFIPyArgSetter setter; + // propagate python error back + if (setter_factory(py_arg, &setter) != 0) { + return -1; + } + // update dispatch table + dispatch_map_.emplace(py_type, setter); + if (setter(ctx, py_arg, out) != 0) return -1; + } + return 0; + } + // internal dispacher + std::unordered_map dispatch_map_; + // temp call stack + std::vector temp_stack_; + int64_t stack_top_ = 0; +}; + +/*! + * \brief Call a function with a variable number of arguments + * \param setter_factory The factory function to create the setter + * \param func_handle The handle of the function to call + * \param py_arg_tuple The arguments to the function + * \param result The result of the function + * \param c_api_ret_code The return code of the function + * \return 0 on success, nonzero on failure + */ +inline int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, + PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code) { + return TVMFFIPyCallManager::ThreadLocal()->Call(setter_factory, func_handle, py_arg_tuple, result, + c_api_ret_code); +} + +/*! + * \brief Set a field of a FFI object + * \param setter_factory The factory function to create the setter + * \param field_setter The field setter function + * \param field_ptr The pointer to the field + * \param py_arg The python argument to be set + * \param c_api_ret_code The return code of the function + * \return 0 on success, nonzero on failure + */ +inline int TVMFFIPyCallFieldSetter(TVMFFIPyArgSetterFactory setter_factory, + TVMFFIFieldSetter field_setter, void* field_ptr, + PyObject* py_arg, int* c_api_ret_code) { + return TVMFFIPyCallManager::ThreadLocal()->SetField(setter_factory, field_setter, field_ptr, + py_arg, c_api_ret_code); +} + +/*! + * \brief Convert a Python object to a FFI Any + * \param setter_factory The factory function to create the setter + * \param py_arg The python argument to be set + * \param out The output argument + * \param c_api_ret_code The return code of the function + * \return 0 on success, nonzero on failure + */ +inline int TVMFFIPyPyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, PyObject* py_arg, + TVMFFIAny* out, int* c_api_ret_code) { + return TVMFFIPyCallManager::ThreadLocal()->PyObjectToFFIAny(setter_factory, py_arg, out, + c_api_ret_code); +} + +/*! + * \brief Get the size of the dispatch map + * \return The size of the dispatch map + */ +inline size_t TVMFFIPyGetDispatchMapSize() { + return TVMFFIPyCallManager::ThreadLocal()->GetDispatchMapSize(); +} + +/*! + * \brief Push a temporary FFI object to the call context that will be recycled after the call + * \param ctx The call context + * \param arg The FFI object to push + */ +inline void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept { + // invariance: each ArgSetter can have at most one temporary Python object + // so it ensures that we won't overflow the temporary Python object stack + ctx->temp_ffi_objects[ctx->num_temp_ffi_objects++] = arg; +} + +/*! + * \brief Push a temporary Python object to the call context that will be recycled after the call + * \param ctx The call context + * \param arg The Python object to push + */ +inline void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept { + // invariance: each ArgSetter can have at most one temporary Python object + // so it ensures that we won't overflow the temporary Python object stack + Py_IncRef(arg); + ctx->temp_py_objects[ctx->num_temp_py_objects++] = arg; +} +#endif // TVM_FFI_PYTHON_HELPERS_H_ diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py index 00581eb0f307..364afa1b5fdf 100644 --- a/ffi/scripts/benchmark_dlpack.py +++ b/ffi/scripts/benchmark_dlpack.py @@ -237,6 +237,7 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat): """ nop = tvm_ffi.get_global_func("testing.nop") nop(x, y, z) + eps = 1e-6 start = time.time() for i in range(repeat): nop(x, y, z) @@ -375,8 +376,19 @@ def bench_torch_get_current_stream(repeat, name, func): print_speed(f"torch.cuda.current_stream[{name}]", speed) +def populate_object_table(num_classes): + nop = tvm_ffi.get_global_func("testing.nop") + dummy_instances = [type(f"DummyClass{i}", (object,), {})() for i in range(num_classes)] + for instance in dummy_instances: + nop(instance) + + def main(): repeat = 10000 + # measures impact of object dispatch table size + # takeaway so far is that there is no impact on the performance + num_classes = 0 + populate_object_table(num_classes) print("-----------------------------") print("Benchmark f(x, y, z) overhead") print("-----------------------------") @@ -423,6 +435,10 @@ def main(): repeat, "cpp-extension", load_torch_get_current_cuda_stream() ) bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) + print("---------------------------------------------------") + print("Benchmark tvm_ffi.print_helper_info") + print("---------------------------------------------------") + tvm_ffi.core._print_debug_info() if __name__ == "__main__": diff --git a/ffi/src/ffi/extra/testing.cc b/ffi/src/ffi/extra/testing.cc index 1b2862a46c1d..54bf7ba35234 100644 --- a/ffi/src/ffi/extra/testing.cc +++ b/ffi/src/ffi/extra/testing.cc @@ -113,7 +113,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ refl::GlobalDef() .def("testing.test_raise_error", TestRaiseError) - .def_packed("testing.nop", [](PackedArgs args, Any* ret) { *ret = args[0]; }) + .def_packed("testing.nop", [](PackedArgs args, Any* ret) {}) .def_packed("testing.echo", [](PackedArgs args, Any* ret) { *ret = args[0]; }) .def_packed("testing.apply", TestApply) .def("testing.run_check_signal",