From 9b995aad74ca3aff9fba3619764665f9a2c72a69 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 21 Aug 2025 14:14:27 -0400 Subject: [PATCH] [Thrust] Fix getting CUDA stream This PR updates the `GetCUDAStream` in CUDA thrust integration to the latest `TVMFFIEnvGetCurrentStream` interface. --- src/runtime/contrib/thrust/thrust.cu | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 45c1bcc7cf3d..1adf95f69320 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -91,7 +92,10 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { }; auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) { - return thrust::cuda::par_nosync(memory_resouce).on(GetCUDAStream()); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + return thrust::cuda::par_nosync(memory_resouce).on(stream); } // Performs sorting along axis -1 and returns both sorted values and indices.