From aa2c271ab4fd64d1e5afbfdc2bdf5217e2c838b4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 9 Sep 2025 10:37:42 -0400 Subject: [PATCH] [FFI][ABI] Introduce generic stream exchange protocol This PR adds a __tvm_ffi_env_stream__ protocol for generic tensors to exchange env stream to tvm ffi. Also renames TVMFFIEnvSetStream to TVMFFIEnvSetCurrentStream. --- ffi/include/tvm/ffi/extra/c_env_api.h | 6 +- ffi/python/tvm_ffi/cython/base.pxi | 91 ++++++++++++++--------- ffi/python/tvm_ffi/cython/function.pxi | 29 ++++++-- ffi/python/tvm_ffi/cython/tensor.pxi | 24 ++++++ ffi/scripts/benchmark_dlpack.py | 26 ++++++- ffi/src/ffi/extra/stream_context.cc | 4 +- src/runtime/device_api.cc | 3 +- src/runtime/vm/cuda/cuda_graph_builtin.cc | 7 +- 8 files changed, 134 insertions(+), 56 deletions(-) diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h index 6f8e44bdfb9c..bd0d188155fe 100644 --- a/ffi/include/tvm/ffi/extra/c_env_api.h +++ b/ffi/include/tvm/ffi/extra/c_env_api.h @@ -49,9 +49,9 @@ typedef void* TVMFFIStreamHandle; * \note The stream is a weak reference that is cached/owned by the module. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, - TVMFFIStreamHandle stream, - TVMFFIStreamHandle* opt_out_original_stream); +TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, + TVMFFIStreamHandle stream, + TVMFFIStreamHandle* opt_out_original_stream); /*! * \brief FFI function to get the current stream for a device diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi index f1cd77bc47e8..efb2225453f5 100644 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ b/ffi/python/tvm_ffi/cython/base.pxi @@ -24,39 +24,24 @@ from cpython cimport PyErr_CheckSignals, PyGILState_Ensure, PyGILState_Release, from cpython cimport pycapsule, PyCapsule_Destructor from cpython cimport PyErr_SetNone - -# Cython binding for TVM FFI C API -cdef extern from "tvm/ffi/c_api.h": - cdef enum TVMFFITypeIndex: - kTVMFFIAny = -1 - kTVMFFINone = 0 - kTVMFFIInt = 1 - kTVMFFIBool = 2 - kTVMFFIFloat = 3 - kTVMFFIOpaquePtr = 4 - kTVMFFIDataType = 5 - kTVMFFIDevice = 6 - kTVMFFIDLTensorPtr = 7 - kTVMFFIRawStr = 8 - kTVMFFIByteArrayPtr = 9 - kTVMFFIObjectRValueRef = 10 - kTVMFFISmallStr = 11 - kTVMFFISmallBytes = 12 - kTVMFFIStaticObjectBegin = 64 - kTVMFFIObject = 64 - kTVMFFIStr = 65 - kTVMFFIBytes = 66 - kTVMFFIError = 67 - kTVMFFIFunction = 68 - kTVMFFIShape = 69 - kTVMFFITensor = 70 - kTVMFFIArray = 71 - kTVMFFIMap = 72 - kTVMFFIModule = 73 - kTVMFFIOpaquePyObject = 74 - - - ctypedef void* TVMFFIObjectHandle +cdef extern from "dlpack/dlpack.h": + cdef enum: + kDLCPU = 1, + kDLCUDA = 2, + kDLCUDAHost = 3, + kDLOpenCL = 4, + kDLVulkan = 7, + kDLMetal = 8, + kDLVPI = 9, + kDLROCM = 10, + kDLROCMHost = 11, + kDLExtDev = 12, + kDLCUDAManaged = 13, + kDLOneAPI = 14, + kDLWebGPU = 15, + kDLHexagon = 16, + kDLMAIA = 17 + kDLTrn = 18 ctypedef struct DLDataType: uint8_t code @@ -92,6 +77,40 @@ cdef extern from "tvm/ffi/c_api.h": void (*deleter)(DLManagedTensorVersioned* self) uint64_t flags + +# Cython binding for TVM FFI C API +cdef extern from "tvm/ffi/c_api.h": + cdef enum TVMFFITypeIndex: + kTVMFFIAny = -1 + kTVMFFINone = 0 + kTVMFFIInt = 1 + kTVMFFIBool = 2 + kTVMFFIFloat = 3 + kTVMFFIOpaquePtr = 4 + kTVMFFIDataType = 5 + kTVMFFIDevice = 6 + kTVMFFIDLTensorPtr = 7 + kTVMFFIRawStr = 8 + kTVMFFIByteArrayPtr = 9 + kTVMFFIObjectRValueRef = 10 + kTVMFFISmallStr = 11 + 
kTVMFFISmallBytes = 12
+        kTVMFFIStaticObjectBegin = 64
+        kTVMFFIObject = 64
+        kTVMFFIStr = 65
+        kTVMFFIBytes = 66
+        kTVMFFIError = 67
+        kTVMFFIFunction = 68
+        kTVMFFIShape = 69
+        kTVMFFITensor = 70
+        kTVMFFIArray = 71
+        kTVMFFIMap = 72
+        kTVMFFIModule = 73
+        kTVMFFIOpaquePyObject = 74
+
+
+    ctypedef void* TVMFFIObjectHandle
+
     ctypedef struct TVMFFIObject:
         int32_t type_index
         int32_t ref_counter
@@ -219,9 +238,9 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
     int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil
     void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil
-    int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
-                           TVMFFIStreamHandle stream,
-                           TVMFFIStreamHandle* opt_out_original_stream) nogil
+    int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
+                                  TVMFFIStreamHandle stream,
+                                  TVMFFIStreamHandle* opt_out_original_stream) nogil


 cdef class ByteArrayArg:
diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi
index 28d4ba5a0094..71591d95267d 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -122,10 +122,25 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
                 ctx_stream[0] = temp_ptr
             temp_args.append(arg)
         elif hasattr(arg, "__dlpack__"):
-            arg = from_dlpack(arg)
+            ffi_arg = from_dlpack(arg)
             out[i].type_index = kTVMFFITensor
-            out[i].v_ptr = (arg).chandle
-            temp_args.append(arg)
+            out[i].v_ptr = (ffi_arg).chandle
+            # record the stream from the source framework context when possible
+            temp_dltensor = TVMFFITensorGetDLTensorPtr((ffi_arg).chandle)
+            if (temp_dltensor.device.device_type != kDLCPU and
+                ctx_dev_type != NULL and
+                ctx_dev_type[0] == -1):
+                # __tvm_ffi_env_stream__ returns the expected stream that should be set
+                # through TVMFFIEnvSetCurrentStream when calling a TVM FFI function
+                if hasattr(arg, "__tvm_ffi_env_stream__"):
+                    # Ideally, projects should make their stream context API write through
+                    # by also calling TVMFFIEnvSetCurrentStream, so this exchange protocol
+                    # would not be needed
+                    ctx_dev_type[0] = temp_dltensor.device.device_type
+                    ctx_dev_id[0] = temp_dltensor.device.device_id
+                    temp_ptr = arg.__tvm_ffi_env_stream__()
+                    ctx_stream[0] = temp_ptr
+            temp_args.append(ffi_arg)
         elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None:
             arg = arg.__tvm_ffi_object__
             out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle)
@@ -210,7 +225,7 @@ cdef inline int FuncCall3(void* chandle,
     with nogil:
         if ctx_dev_type != -1:
             # set the stream based on ctx stream
-            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
+            c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
             if c_api_ret_code[0] != 0:
                 return 0
         c_api_ret_code[0] = TVMFFIFunctionCall(
@@ -219,7 +234,7 @@
         # restore the original stream if it is not the same as the context stream
         if ctx_dev_type != -1 and prev_stream != ctx_stream:
             # restore the original stream
-            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
+            c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
             if c_api_ret_code[0] != 0:
                 return 0
     return 0
@@ -247,13 +262,13 @@ cdef inline int FuncCall(void* chandle,

     with nogil:
         if ctx_dev_type != -1:
-            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
+            c_api_ret_code[0] = 
TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
             if c_api_ret_code[0] != 0:
                 return 0
         c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result)
         # restore the original stream if it is not the same as the context stream
         if ctx_dev_type != -1 and prev_stream != ctx_stream:
-            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
+            c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
             if c_api_ret_code[0] != 0:
                 return 0
diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi b/ffi/python/tvm_ffi/cython/tensor.pxi
index 4658422ca524..2072ad056797 100644
--- a/ffi/python/tvm_ffi/cython/tensor.pxi
+++ b/ffi/python/tvm_ffi/cython/tensor.pxi
@@ -260,6 +260,30 @@ _set_class_tensor(Tensor)
 _register_object_by_index(kTVMFFITensor, Tensor)


+cdef class DLTensorTestWrapper:
+    """Wrapper of a Tensor that exposes the DLPack protocol, only for testing purposes.
+    """
+    cdef Tensor tensor
+    def __init__(self, tensor):
+        self.tensor = tensor
+
+    def __tvm_ffi_env_stream__(self):
+        cdef TVMFFIStreamHandle stream
+        cdef long long stream_as_int
+        cdef int c_api_ret_code
+        with nogil:
+            stream = TVMFFIEnvGetCurrentStream(
+                self.tensor.cdltensor.device.device_type, self.tensor.cdltensor.device.device_id)
+            stream_as_int = <long long>stream
+        return stream_as_int
+
+    def __dlpack_device__(self):
+        return self.tensor.__dlpack_device__()
+
+    def __dlpack__(self, **kwargs):
+        return self.tensor.__dlpack__(**kwargs)
+
+
 cdef inline object make_ret_dltensor(TVMFFIAny result):
     cdef DLTensor* dltensor
     dltensor = <DLTensor*>result.v_ptr
diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py
index 73fbe0f6ac22..00581eb0f307 100644
--- a/ffi/scripts/benchmark_dlpack.py
+++ b/ffi/scripts/benchmark_dlpack.py
@@ -44,11 +44,11 @@


 def print_speed(name, speed):
-    print(f"{name:<40} {speed} sec/call")
+    print(f"{name:<60} {speed} sec/call")


 def print_error(name, error):
-    print(f"{name:<40} {error}")
+    print(f"{name:<60} {error}")


 def baseline_torch_add(repeat):
@@ -122,7 +122,7 @@ def tvm_ffi_nop(repeat):
     nop(x, y, z)
     start = time.time()
     for i in range(repeat):
-        y = tvm_ffi.from_dlpack(x)
+        nop(x, y, z)
     end = time.time()
     print_speed("tvm_ffi.nop", (end - start) / repeat)

@@ -275,6 +275,22 @@ def tvm_ffi_nop_autodlpack_from_numpy(repeat):
     bench_tvm_ffi_nop_autodlpack("tvm_ffi.nop.autodlpack(numpy)", x, y, z, repeat)


+def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device):
+    """
+    Measure the overhead of the auto DLPack conversion path by passing the test
+    wrapper directly as inputs. This effectively measures the DLPack exchange cost in tvm ffi.
+    """
+    x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+    y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+    z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+    x = tvm_ffi.core.DLTensorTestWrapper(x)
+    y = tvm_ffi.core.DLTensorTestWrapper(y)
+    z = tvm_ffi.core.DLTensorTestWrapper(z)
+    bench_tvm_ffi_nop_autodlpack(
+        f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z, repeat
+    )
+
+
 def bench_to_dlpack(x, name, repeat):
     x.__dlpack__()
     start = time.time()
@@ -367,7 +383,6 @@ def main():
     baseline_numpy_add(repeat)
     baseline_torch_add(repeat)
     baseline_cupy_add(repeat)
-    tvm_ffi_nop(repeat)
     tvm_ffi_nop_from_torch_dlpack(repeat)
     tvm_ffi_nop_from_numpy_dlpack(repeat)
     tvm_ffi_self_dlpack_nop(repeat)
@@ -377,6 +392,9 @@
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
     tvm_ffi_nop_autodlpack_from_numpy(repeat)

+    tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
+    tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
+    tvm_ffi_nop(repeat)
     print("-------------------------------")
     print("Benchmark x.__dlpack__ overhead")
     print("-------------------------------")
diff --git a/ffi/src/ffi/extra/stream_context.cc b/ffi/src/ffi/extra/stream_context.cc
index d063efdef579..5a6afad4c1d8 100644
--- a/ffi/src/ffi/extra/stream_context.cc
+++ b/ffi/src/ffi/extra/stream_context.cc
@@ -66,8 +66,8 @@ class StreamContext {
 }  // namespace ffi
 }  // namespace tvm

-int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
-                       TVMFFIStreamHandle* out_original_stream) {
+int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
+                              TVMFFIStreamHandle* out_original_stream) {
   TVM_FFI_SAFE_CALL_BEGIN();
   tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream,
                                                     out_original_stream);
diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc
index fd7d651df2f4..e574ce14b004 100644
--- a/src/runtime/device_api.cc
+++ b/src/runtime/device_api.cc
@@ -165,7 +165,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }
 void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}

 void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
-  TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr));
+  TVM_FFI_CHECK_SAFE_CALL(
+      TVMFFIEnvSetCurrentStream(dev.device_type, dev.device_id, stream, nullptr));
 }

 TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) {
diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc
index a85ade2e1d8d..252841528152 100644
--- a/src/runtime/vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc
@@ -118,13 +118,14 @@ class CUDACaptureStream {
   explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) {
     CUDA_CALL(cudaGetDevice(&device_id_));
     TVM_FFI_CHECK_SAFE_CALL(
-        TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
-                           reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
+        TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, capture_stream_,
+                                  reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
     CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
   }
   ~CUDACaptureStream() noexcept(false) {
     cudaStreamEndCapture(capture_stream_, output_graph_);
-    TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr));
+    TVM_FFI_CHECK_SAFE_CALL(
+        TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, prev_default_stream_, nullptr));
   }

  private:
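
Note (illustration, not part of the patch): the __tvm_ffi_env_stream__ protocol added above lets any object that already implements __dlpack__ also report the stream its framework is currently using, so that make_args can install it via TVMFFIEnvSetCurrentStream for the duration of the call and restore the previous stream afterwards. Below is a minimal sketch of how an external project could adopt the protocol for CUDA tensors; the wrapper class name is hypothetical and torch's current-stream API is used only for illustration.

    import torch


    class TorchStreamAwareTensor:
        """Hypothetical wrapper: forwards DLPack and reports the torch stream."""

        def __init__(self, tensor):
            self._tensor = tensor

        def __dlpack_device__(self):
            return self._tensor.__dlpack_device__()

        def __dlpack__(self, **kwargs):
            return self._tensor.__dlpack__(**kwargs)

        def __tvm_ffi_env_stream__(self):
            # Return the raw stream pointer as a plain int; tvm_ffi passes it to
            # TVMFFIEnvSetCurrentStream around the FFI call and restores the old
            # stream afterwards. Only meaningful for CUDA tensors; tvm_ffi consults
            # this protocol only for non-CPU devices.
            return torch.cuda.current_stream(self._tensor.device).cuda_stream

As the make_args comment notes, a framework can avoid this round trip entirely by making its own stream-context API write through to TVMFFIEnvSetCurrentStream.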
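A second sketch (also illustrative) exercises the auto-DLPack plus stream-exchange path through the DLTensorTestWrapper added in tensor.pxi, mirroring the new benchmark; the "testing.nop" global function name is an assumption here, so substitute any registered tvm_ffi function that accepts tensor arguments.

    import torch
    import tvm_ffi

    # assumed: a registered no-op global function that takes three tensors
    nop = tvm_ffi.get_global_func("testing.nop")

    x = tvm_ffi.from_dlpack(torch.arange(3, device="cuda"))
    wrapped = tvm_ffi.core.DLTensorTestWrapper(x)

    # make_args detects __dlpack__ and __tvm_ffi_env_stream__ on the wrapper,
    # converts it to a tvm_ffi Tensor, and brackets the call with
    # TVMFFIEnvSetCurrentStream using the reported stream.
    nop(wrapped, wrapped, wrapped)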