diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 45c1bcc7cf3d..1adf95f69320 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -91,7 +92,10 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { }; auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) { - return thrust::cuda::par_nosync(memory_resouce).on(GetCUDAStream()); + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + return thrust::cuda::par_nosync(memory_resouce).on(stream); } // Performs sorting along axis -1 and returns both sorted values and indices.