Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ffi/include/tvm/ffi/extra/c_env_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ typedef void* TVMFFIStreamHandle;
* \note The stream is a weak reference that is cached/owned by the module.
* \return 0 when success, nonzero when failure happens
*/
TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
TVMFFIStreamHandle* opt_out_original_stream);
TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
TVMFFIStreamHandle* opt_out_original_stream);

/*!
* \brief FFI function to get the current stream for a device
Expand Down
91 changes: 55 additions & 36 deletions ffi/python/tvm_ffi/cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,24 @@ from cpython cimport PyErr_CheckSignals, PyGILState_Ensure, PyGILState_Release,
from cpython cimport pycapsule, PyCapsule_Destructor
from cpython cimport PyErr_SetNone


# Cython binding for TVM FFI C API
cdef extern from "tvm/ffi/c_api.h":
cdef enum TVMFFITypeIndex:
kTVMFFIAny = -1
kTVMFFINone = 0
kTVMFFIInt = 1
kTVMFFIBool = 2
kTVMFFIFloat = 3
kTVMFFIOpaquePtr = 4
kTVMFFIDataType = 5
kTVMFFIDevice = 6
kTVMFFIDLTensorPtr = 7
kTVMFFIRawStr = 8
kTVMFFIByteArrayPtr = 9
kTVMFFIObjectRValueRef = 10
kTVMFFISmallStr = 11
kTVMFFISmallBytes = 12
kTVMFFIStaticObjectBegin = 64
kTVMFFIObject = 64
kTVMFFIStr = 65
kTVMFFIBytes = 66
kTVMFFIError = 67
kTVMFFIFunction = 68
kTVMFFIShape = 69
kTVMFFITensor = 70
kTVMFFIArray = 71
kTVMFFIMap = 72
kTVMFFIModule = 73
kTVMFFIOpaquePyObject = 74


ctypedef void* TVMFFIObjectHandle
cdef extern from "dlpack/dlpack.h":
cdef enum:
kDLCPU = 1,
kDLCUDA = 2,
kDLCUDAHost = 3,
kDLOpenCL = 4,
kDLVulkan = 7,
kDLMetal = 8,
kDLVPI = 9,
kDLROCM = 10,
kDLROCMHost = 11,
kDLExtDev = 12,
kDLCUDAManaged = 13,
kDLOneAPI = 14,
kDLWebGPU = 15,
kDLHexagon = 16,
kDLMAIA = 17
kDLTrn = 18

ctypedef struct DLDataType:
uint8_t code
Expand Down Expand Up @@ -92,6 +77,40 @@ cdef extern from "tvm/ffi/c_api.h":
void (*deleter)(DLManagedTensorVersioned* self)
uint64_t flags


# Cython binding for TVM FFI C API
cdef extern from "tvm/ffi/c_api.h":
cdef enum TVMFFITypeIndex:
kTVMFFIAny = -1
kTVMFFINone = 0
kTVMFFIInt = 1
kTVMFFIBool = 2
kTVMFFIFloat = 3
kTVMFFIOpaquePtr = 4
kTVMFFIDataType = 5
kTVMFFIDevice = 6
kTVMFFIDLTensorPtr = 7
kTVMFFIRawStr = 8
kTVMFFIByteArrayPtr = 9
kTVMFFIObjectRValueRef = 10
kTVMFFISmallStr = 11
kTVMFFISmallBytes = 12
kTVMFFIStaticObjectBegin = 64
kTVMFFIObject = 64
kTVMFFIStr = 65
kTVMFFIBytes = 66
kTVMFFIError = 67
kTVMFFIFunction = 68
kTVMFFIShape = 69
kTVMFFITensor = 70
kTVMFFIArray = 71
kTVMFFIMap = 72
kTVMFFIModule = 73
kTVMFFIOpaquePyObject = 74


ctypedef void* TVMFFIObjectHandle

ctypedef struct TVMFFIObject:
int32_t type_index
int32_t ref_counter
Expand Down Expand Up @@ -219,9 +238,9 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":

int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil
void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil
int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
TVMFFIStreamHandle* opt_out_original_stream) nogil
int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
TVMFFIStreamHandle* opt_out_original_stream) nogil


cdef class ByteArrayArg:
Expand Down
29 changes: 22 additions & 7 deletions ffi/python/tvm_ffi/cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,25 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
temp_args.append(arg)
elif hasattr(arg, "__dlpack__"):
arg = from_dlpack(arg)
ffi_arg = from_dlpack(arg)
out[i].type_index = kTVMFFITensor
out[i].v_ptr = (<Tensor>arg).chandle
temp_args.append(arg)
out[i].v_ptr = (<Tensor>ffi_arg).chandle
# record the stream from the source framework context when possible
temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>ffi_arg).chandle)
if (temp_dltensor.device.device_type != kDLCPU and
ctx_dev_type != NULL and
ctx_dev_type[0] == -1):
# __tvm_ffi_env_stream__ returns the expected stream that should be set
# through TVMFFIEnvSetCurrentStream when calling a TVM FFI function
if hasattr(arg, "__tvm_ffi_env_stream__"):
# Ideally, projects should make their own stream context API write
# through by also calling TVMFFIEnvSetCurrentStream, so that this
# exchange protocol is not needed.
ctx_dev_type[0] = temp_dltensor.device.device_type
ctx_dev_id[0] = temp_dltensor.device.device_id
temp_ptr= arg.__tvm_ffi_env_stream__()
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
temp_args.append(ffi_arg)
elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None:
arg = arg.__tvm_ffi_object__
out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
Expand Down Expand Up @@ -210,7 +225,7 @@ cdef inline int FuncCall3(void* chandle,
with nogil:
if ctx_dev_type != -1:
# set the stream based on ctx stream
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
if c_api_ret_code[0] != 0:
return 0
c_api_ret_code[0] = TVMFFIFunctionCall(
Expand All @@ -219,7 +234,7 @@ cdef inline int FuncCall3(void* chandle,
# restore the original stream if it is not the same as the context stream
if ctx_dev_type != -1 and prev_stream != ctx_stream:
# restore the original stream
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
if c_api_ret_code[0] != 0:
return 0
return 0
Expand Down Expand Up @@ -247,13 +262,13 @@ cdef inline int FuncCall(void* chandle,

with nogil:
if ctx_dev_type != -1:
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
if c_api_ret_code[0] != 0:
return 0
c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result)
# restore the original stream if it is not the same as the context stream
if ctx_dev_type != -1 and prev_stream != ctx_stream:
c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
if c_api_ret_code[0] != 0:
return 0

Expand Down
24 changes: 24 additions & 0 deletions ffi/python/tvm_ffi/cython/tensor.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,30 @@ _set_class_tensor(Tensor)
_register_object_by_index(kTVMFFITensor, Tensor)


cdef class DLTensorTestWrapper:
    """Wrapper of a Tensor that exposes the DLPack protocol, for testing purposes only.

    The wrapper forwards ``__dlpack__`` and ``__dlpack_device__`` to the wrapped
    tensor and additionally implements ``__tvm_ffi_env_stream__``, which reports
    the stream returned by ``TVMFFIEnvGetCurrentStream`` for the tensor's device.
    """
    cdef Tensor tensor

    def __init__(self, tensor):
        # Keep a reference to the wrapped tensor; every protocol method delegates to it.
        self.tensor = tensor

    def __tvm_ffi_env_stream__(self):
        """Return the current env stream for the wrapped tensor's device as an integer."""
        cdef TVMFFIStreamHandle stream
        cdef long long stream_as_int
        with nogil:
            stream = TVMFFIEnvGetCurrentStream(
                self.tensor.cdltensor.device.device_type,
                self.tensor.cdltensor.device.device_id)
            # The handle is an opaque pointer; hand it back to Python as a plain integer.
            stream_as_int = <long long>stream
        return stream_as_int

    def __dlpack_device__(self):
        """Forward to the wrapped tensor's ``__dlpack_device__``."""
        return self.tensor.__dlpack_device__()

    def __dlpack__(self, *, **kwargs):
        """Forward to the wrapped tensor's ``__dlpack__`` with the same keyword arguments."""
        return self.tensor.__dlpack__(**kwargs)


cdef inline object make_ret_dltensor(TVMFFIAny result):
cdef DLTensor* dltensor
dltensor = <DLTensor*>result.v_ptr
Expand Down
26 changes: 22 additions & 4 deletions ffi/scripts/benchmark_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@


def print_speed(name, speed):
print(f"{name:<40} {speed} sec/call")
print(f"{name:<60} {speed} sec/call")


def print_error(name, error):
print(f"{name:<40} {error}")
print(f"{name:<60} {error}")


def baseline_torch_add(repeat):
Expand Down Expand Up @@ -122,7 +122,7 @@ def tvm_ffi_nop(repeat):
nop(x, y, z)
start = time.time()
for i in range(repeat):
y = tvm_ffi.from_dlpack(x)
nop(x, y, z)
end = time.time()
print_speed("tvm_ffi.nop", (end - start) / repeat)

Expand Down Expand Up @@ -275,6 +275,22 @@ def tvm_ffi_nop_autodlpack_from_numpy(repeat):
bench_tvm_ffi_nop_autodlpack("tvm_ffi.nop.autodlpack(numpy)", x, y, z, repeat)


def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device):
    """
    Measure the overhead of automatic DLPack conversion when the inputs are
    DLTensorTestWrapper objects. This effectively measures the DLPack exchange
    in tvm ffi.
    """
    # Three independent single-element tensors on the requested device,
    # each imported into tvm_ffi and then wrapped for the benchmark.
    x, y, z = [
        tvm_ffi.core.DLTensorTestWrapper(
            tvm_ffi.from_dlpack(torch.arange(1, device=device))
        )
        for _ in range(3)
    ]
    label = f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])"
    bench_tvm_ffi_nop_autodlpack(label, x, y, z, repeat)


def bench_to_dlpack(x, name, repeat):
x.__dlpack__()
start = time.time()
Expand Down Expand Up @@ -367,7 +383,6 @@ def main():
baseline_numpy_add(repeat)
baseline_torch_add(repeat)
baseline_cupy_add(repeat)
tvm_ffi_nop(repeat)
tvm_ffi_nop_from_torch_dlpack(repeat)
tvm_ffi_nop_from_numpy_dlpack(repeat)
tvm_ffi_self_dlpack_nop(repeat)
Expand All @@ -377,6 +392,9 @@ def main():
tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)

tvm_ffi_nop_autodlpack_from_numpy(repeat)
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
tvm_ffi_nop(repeat)
print("-------------------------------")
print("Benchmark x.__dlpack__ overhead")
print("-------------------------------")
Expand Down
4 changes: 2 additions & 2 deletions ffi/src/ffi/extra/stream_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class StreamContext {
} // namespace ffi
} // namespace tvm

int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
TVMFFIStreamHandle* out_original_stream) {
int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
TVMFFIStreamHandle* out_original_stream) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream,
out_original_stream);
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }
void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}

void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr));
TVM_FFI_CHECK_SAFE_CALL(
TVMFFIEnvSetCurrentStream(dev.device_type, dev.device_id, stream, nullptr));
}

TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) {
Expand Down
7 changes: 4 additions & 3 deletions src/runtime/vm/cuda/cuda_graph_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,14 @@ class CUDACaptureStream {
explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) {
CUDA_CALL(cudaGetDevice(&device_id_));
TVM_FFI_CHECK_SAFE_CALL(
TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, capture_stream_,
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
}
~CUDACaptureStream() noexcept(false) {
cudaStreamEndCapture(capture_stream_, output_graph_);
TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr));
TVM_FFI_CHECK_SAFE_CALL(
TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, prev_default_stream_, nullptr));
}

private:
Expand Down
Loading