diff --git a/unified-runtime/source/adapters/cuda/usm.cpp b/unified-runtime/source/adapters/cuda/usm.cpp index c805c1084ec0f..1fdfa38cb71ec 100644 --- a/unified-runtime/source/adapters/cuda/usm.cpp +++ b/unified-runtime/source/adapters/cuda/usm.cpp @@ -574,10 +574,41 @@ urUSMPoolTrimToExp(ur_context_handle_t hContext, ur_device_handle_t hDevice, return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL urUSMContextMemcpyExp(ur_context_handle_t, - void *pDst, - const void *pSrc, - size_t Size) { +UR_APIEXPORT ur_result_t UR_APICALL urUSMContextMemcpyExp( + ur_context_handle_t hContext, void *pDst, const void *pSrc, size_t Size) { + + CUmemorytype memTypeDst; + UR_CHECK_ERROR(cuPointerGetAttribute( + &memTypeDst, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, (CUdeviceptr)pDst)); + + CUmemorytype memTypeSrc; + UR_CHECK_ERROR(cuPointerGetAttribute( + &memTypeSrc, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, (CUdeviceptr)pSrc)); + + if (memTypeDst != CU_MEMORYTYPE_DEVICE || + memTypeSrc != CU_MEMORYTYPE_DEVICE) { + UR_CHECK_ERROR(cuMemcpy((CUdeviceptr)pDst, (CUdeviceptr)pSrc, Size)); + return UR_RESULT_SUCCESS; + } + + // For device to device copy, we need to synchronize the host with the memcpy + // operation to ensure the copy is completed. + // For details, see: + // https://docs.nvidia.com/cuda/cuda-driver-api/api-sync-behavior.html. + unsigned int devIdx = 0; + UR_CHECK_ERROR(cuPointerGetAttribute( + &devIdx, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, (CUdeviceptr)pDst)); + if (devIdx >= hContext->getDevices().size()) { + return UR_RESULT_ERROR_INVALID_CONTEXT; + } + + // Device ordinal obtained using cuPointerGetAttribute corresponds to the + // device index in the context's device list. + ur_device_handle_t owningDev = hContext->getDevices()[devIdx]; + + ScopedContext Active(owningDev); UR_CHECK_ERROR(cuMemcpy((CUdeviceptr)pDst, (CUdeviceptr)pSrc, Size)); + UR_CHECK_ERROR(cuCtxSynchronize()); + return UR_RESULT_SUCCESS; }