From 8d14c8a6b310896301a455c9699311d08968f932 Mon Sep 17 00:00:00 2001 From: Hugh Delaney Date: Thu, 7 Dec 2023 14:42:27 +0000 Subject: [PATCH] Use device to get native context SYCL contexts have a many-to-one mapping to native contexts. Therefore it is necessary to get the desired native context from a SYCL device, as SYCL devices have a one-to-one mapping to native contexts. --- src/blas/backends/cublas/cublas_scope_handle.cpp | 12 ++++++++---- src/blas/backends/rocblas/rocblas_scope_handle.cpp | 12 ++++++++---- src/dft/backends/cufft/commit.cpp | 7 ++++--- .../backends/cusolver/cusolver_scope_handle.cpp | 12 ++++++++---- .../backends/rocsolver/rocsolver_scope_handle.cpp | 11 ++++++++--- 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index f3e39ca11..05d1c1935 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -42,10 +42,11 @@ CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl:: : ih(ih), needToRecover_(false) { placedContext_ = new sycl::context(queue.get_context()); - auto device = queue.get_device(); - auto desired = sycl::get_native(*placedContext_); + auto cudaDevice = ih.get_native_device(); CUresult err; + CUcontext desired; CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); if (original_ != desired) { // Sets the desired context as the active one for the thread CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); @@ -87,8 +88,11 @@ void ContextCallback(void *userData) { } cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue &queue) { - auto piPlacedContext_ = reinterpret_cast( - sycl::get_native(*placedContext_)); + auto cudaDevice = ih.get_native_device(); + CUresult cuErr; + CUcontext desired; + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, 
cudaDevice); + auto piPlacedContext_ = reinterpret_cast(desired); CUstream streamId = get_stream(queue); cublasStatus_t err; auto it = handle_helper.cublas_handle_mapper_.find(piPlacedContext_); diff --git a/src/blas/backends/rocblas/rocblas_scope_handle.cpp b/src/blas/backends/rocblas/rocblas_scope_handle.cpp index 2abf0323b..404d1fc06 100644 --- a/src/blas/backends/rocblas/rocblas_scope_handle.cpp +++ b/src/blas/backends/rocblas/rocblas_scope_handle.cpp @@ -58,10 +58,11 @@ RocblasScopedContextHandler::RocblasScopedContextHandler(sycl::queue queue, : interop_h(ih), needToRecover_(false) { placedContext_ = new sycl::context(queue.get_context()); - auto device = queue.get_device(); - auto desired = sycl::get_native(*placedContext_); + auto hipDevice = ih.get_native_device(); hipError_t err; + hipCtx_t desired; HIP_ERROR_FUNC(hipCtxGetCurrent, err, &original_); + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice); if (original_ != desired) { // Sets the desired context as the active one for the thread HIP_ERROR_FUNC(hipCtxSetCurrent, err, desired); @@ -103,8 +104,11 @@ void ContextCallback(void *userData) { } rocblas_handle RocblasScopedContextHandler::get_handle(const sycl::queue &queue) { - auto piPlacedContext_ = reinterpret_cast( - sycl::get_native(*placedContext_)); + auto hipDevice = interop_h.get_native_device(); + hipError_t hipErr; + hipCtx_t desired; + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, hipErr, &desired, hipDevice); + auto piPlacedContext_ = reinterpret_cast(desired); hipStream_t streamId = get_stream(queue); rocblas_status err; auto it = handle_helper.rocblas_handle_container_mapper_.find(piPlacedContext_); diff --git a/src/dft/backends/cufft/commit.cpp b/src/dft/backends/cufft/commit.cpp index 627f1b565..19507d722 100644 --- a/src/dft/backends/cufft/commit.cpp +++ b/src/dft/backends/cufft/commit.cpp @@ -81,9 +81,10 @@ class cufft_commit final : public dft::detail::commit_impl { } if (fix_context) { // cufftDestroy changes the 
context so change it back. - CUcontext interopContext = - sycl::get_native(this->get_queue().get_context()); - if (cuCtxSetCurrent(interopContext) != CUDA_SUCCESS) { + CUdevice interopDevice = + sycl::get_native(this->get_queue().get_device()); + CUcontext interopContext; + if (cuDevicePrimaryCtxRetain(&interopContext, interopDevice) != CUDA_SUCCESS) { throw mkl::exception("dft/backends/cufft", __FUNCTION__, "Failed to change cuda context."); } diff --git a/src/lapack/backends/cusolver/cusolver_scope_handle.cpp b/src/lapack/backends/cusolver/cusolver_scope_handle.cpp index 98fbca125..0bc3ebdb0 100644 --- a/src/lapack/backends/cusolver/cusolver_scope_handle.cpp +++ b/src/lapack/backends/cusolver/cusolver_scope_handle.cpp @@ -43,10 +43,11 @@ CusolverScopedContextHandler::CusolverScopedContextHandler(sycl::queue queue, : ih(ih), needToRecover_(false) { placedContext_ = new sycl::context(queue.get_context()); - auto device = queue.get_device(); - auto desired = sycl::get_native(*placedContext_); + auto cudaDevice = ih.get_native_device(); CUresult err; + CUcontext desired; CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); if (original_ != desired) { // Sets the desired context as the active one for the thread CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); @@ -88,8 +89,11 @@ void ContextCallback(void *userData) { } cusolverDnHandle_t CusolverScopedContextHandler::get_handle(const sycl::queue &queue) { - auto piPlacedContext_ = reinterpret_cast( - sycl::get_native(*placedContext_)); + auto cudaDevice = ih.get_native_device(); + CUresult cuErr; + CUcontext desired; + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice); + auto piPlacedContext_ = reinterpret_cast(desired); CUstream streamId = get_stream(queue); cusolverStatus_t err; auto it = handle_helper.cusolver_handle_mapper_.find(piPlacedContext_); diff --git a/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp 
b/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp index c50d6a5b4..42e262e7b 100644 --- a/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp +++ b/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp @@ -45,9 +45,11 @@ RocsolverScopedContextHandler::RocsolverScopedContextHandler(sycl::queue queue, : ih(ih), needToRecover_(false) { placedContext_ = new sycl::context(queue.get_context()); - auto desired = sycl::get_native(*placedContext_); + auto hipDevice = ih.get_native_device(); hipError_t err; + hipCtx_t desired; HIP_ERROR_FUNC(hipCtxGetCurrent, err, &original_); + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice); if (original_ != desired) { // Sets the desired context as the active one for the thread HIP_ERROR_FUNC(hipCtxSetCurrent, err, desired); @@ -89,8 +91,11 @@ void ContextCallback(void *userData) { } rocblas_handle RocsolverScopedContextHandler::get_handle(const sycl::queue &queue) { - auto piPlacedContext_ = reinterpret_cast( - sycl::get_native(*placedContext_)); + auto hipDevice = ih.get_native_device(); + hipError_t hipErr; + hipCtx_t desired; + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, hipErr, &desired, hipDevice); + auto piPlacedContext_ = reinterpret_cast(desired); hipStream_t streamId = get_stream(queue); rocblas_status err; auto it = handle_helper.rocsolver_handle_mapper_.find(piPlacedContext_);