From 0894af79c297a374b51541da9af93c820e535851 Mon Sep 17 00:00:00 2001 From: Steffen Larsen Date: Wed, 1 Apr 2020 14:32:40 +0100 Subject: [PATCH] [SYCL][CUDA] Fixes active context when creating base event Signed-off-by: Steffen Larsen --- sycl/plugins/cuda/pi_cuda.cpp | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index 81e200d34710c..936bedad90912 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -1413,6 +1413,8 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties, std::unique_ptr<_pi_context> piContextPtr{nullptr}; try { + CUcontext current = nullptr; + if (property_cuda_primary) { // Use the CUDA primary context and assume that we want to use it // immediately as we want to forge context switches. @@ -1424,23 +1426,26 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties, errcode_ret = PI_CHECK_ERROR(cuCtxPushCurrent(Ctxt)); } else { // Create a scoped context. - CUcontext newContext, current; + CUcontext newContext; PI_CHECK_ERROR(cuCtxGetCurrent(¤t)); errcode_ret = PI_CHECK_ERROR( cuCtxCreate(&newContext, CU_CTX_MAP_HOST, devices[0]->get())); piContextPtr = std::unique_ptr<_pi_context>(new _pi_context{ _pi_context::kind::user_defined, newContext, *devices}); - // For scoped contexts keep the last active CUDA one on top of the stack - // as `cuCtxCreate` replaces it implicitly otherwise. - if (current != nullptr) { - PI_CHECK_ERROR(cuCtxSetCurrent(current)); - } } // Use default stream to record base event counter PI_CHECK_ERROR(cuEventCreate(&piContextPtr->evBase_, CU_EVENT_DEFAULT)); PI_CHECK_ERROR(cuEventRecord(piContextPtr->evBase_, 0)); + // For non-primary scoped contexts keep the last active on top of the stack + // as `cuCtxCreate` replaces it implicitly otherwise. + // Primary contexts are kept on top of the stack, so the previous context + // is not queried and therefore not recovered. + if (current != nullptr) { + PI_CHECK_ERROR(cuCtxSetCurrent(current)); + } + *retcontext = piContextPtr.release(); } catch (pi_result err) { errcode_ret = err;