diff --git a/sycl/plugins/cuda/CMakeLists.txt b/sycl/plugins/cuda/CMakeLists.txt index e0a4e67fe27c4..646568018a4e6 100644 --- a/sycl/plugins/cuda/CMakeLists.txt +++ b/sycl/plugins/cuda/CMakeLists.txt @@ -55,6 +55,8 @@ add_sycl_plugin(cuda "../unified_runtime/ur/ur.cpp" "../unified_runtime/ur/usm_allocator.cpp" "../unified_runtime/ur/usm_allocator.hpp" + "../unified_runtime/ur/adapters/cuda/adapter.cpp" + "../unified_runtime/ur/adapters/cuda/adapter.hpp" "../unified_runtime/ur/adapters/cuda/command_buffer.cpp" "../unified_runtime/ur/adapters/cuda/command_buffer.hpp" "../unified_runtime/ur/adapters/cuda/common.cpp" diff --git a/sycl/plugins/hip/CMakeLists.txt b/sycl/plugins/hip/CMakeLists.txt index a75de97f4b118..29ac21230f025 100644 --- a/sycl/plugins/hip/CMakeLists.txt +++ b/sycl/plugins/hip/CMakeLists.txt @@ -92,6 +92,8 @@ add_sycl_plugin(hip "../unified_runtime/pi2ur.cpp" "../unified_runtime/ur/ur.hpp" "../unified_runtime/ur/ur.cpp" + "../unified_runtime/ur/adapters/hip/adapter.cpp" + "../unified_runtime/ur/adapters/hip/adapter.hpp" "../unified_runtime/ur/adapters/hip/command_buffer.cpp" "../unified_runtime/ur/adapters/hip/command_buffer.hpp" "../unified_runtime/ur/adapters/hip/common.cpp" diff --git a/sycl/plugins/level_zero/CMakeLists.txt b/sycl/plugins/level_zero/CMakeLists.txt index bd24ffea80220..8e6cbce7a635a 100755 --- a/sycl/plugins/level_zero/CMakeLists.txt +++ b/sycl/plugins/level_zero/CMakeLists.txt @@ -103,6 +103,7 @@ add_sycl_plugin(level_zero "../unified_runtime/ur/usm_allocator_config.hpp" "../unified_runtime/ur/adapters/level_zero/ur_level_zero.hpp" "../unified_runtime/ur/adapters/level_zero/command_buffer.hpp" + "../unified_runtime/ur/adapters/level_zero/adapter.hpp" "../unified_runtime/ur/adapters/level_zero/common.hpp" "../unified_runtime/ur/adapters/level_zero/context.hpp" "../unified_runtime/ur/adapters/level_zero/device.hpp" @@ -116,6 +117,7 @@ add_sycl_plugin(level_zero "../unified_runtime/ur/adapters/level_zero/sampler.hpp" 
"../unified_runtime/ur/adapters/level_zero/usm.hpp" "../unified_runtime/ur/adapters/level_zero/ur_level_zero.cpp" + "../unified_runtime/ur/adapters/level_zero/adapter.cpp" "../unified_runtime/ur/adapters/level_zero/command_buffer.cpp" "../unified_runtime/ur/adapters/level_zero/common.cpp" "../unified_runtime/ur/adapters/level_zero/context.cpp" diff --git a/sycl/plugins/native_cpu/CMakeLists.txt b/sycl/plugins/native_cpu/CMakeLists.txt index 33e4c1ecb26d1..c4214563673ce 100644 --- a/sycl/plugins/native_cpu/CMakeLists.txt +++ b/sycl/plugins/native_cpu/CMakeLists.txt @@ -5,6 +5,7 @@ add_sycl_plugin(native_cpu "../unified_runtime/pi2ur.cpp" "../unified_runtime/ur/ur.hpp" "../unified_runtime/ur/ur.cpp" + "../unified_runtime/ur/adapters/native_cpu/adapter.cpp" "../unified_runtime/ur/adapters/native_cpu/common.cpp" "../unified_runtime/ur/adapters/native_cpu/common.hpp" "../unified_runtime/ur/adapters/native_cpu/context.cpp" @@ -24,7 +25,6 @@ add_sycl_plugin(native_cpu "../unified_runtime/ur/adapters/native_cpu/queue.cpp" "../unified_runtime/ur/adapters/native_cpu/queue.hpp" "../unified_runtime/ur/adapters/native_cpu/sampler.cpp" - "../unified_runtime/ur/adapters/native_cpu/runtime.cpp" "../unified_runtime/ur/adapters/native_cpu/ur_interface_loader.cpp" "../unified_runtime/ur/adapters/native_cpu/usm.cpp" "../unified_runtime/ur/adapters/native_cpu/usm_p2p.cpp" diff --git a/sycl/plugins/unified_runtime/CMakeLists.txt b/sycl/plugins/unified_runtime/CMakeLists.txt index 69ab0e5bc56a8..f0ff9fec8eb4b 100755 --- a/sycl/plugins/unified_runtime/CMakeLists.txt +++ b/sycl/plugins/unified_runtime/CMakeLists.txt @@ -4,7 +4,7 @@ if (NOT DEFINED UNIFIED_RUNTIME_LIBRARY OR NOT DEFINED UNIFIED_RUNTIME_INCLUDE_D include(FetchContent) set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git") - set(UNIFIED_RUNTIME_TAG 3c6f02c7a76a0448a83932d93c2dbeff25af70aa) + set(UNIFIED_RUNTIME_TAG 974a7d64dd1a26ede1ff27919b3b8713b848c376) message(STATUS "Will fetch Unified Runtime from 
${UNIFIED_RUNTIME_REPO}") FetchContent_Declare(unified-runtime @@ -86,6 +86,7 @@ add_sycl_library("ur_adapter_level_zero" SHARED "ur/adapters/level_zero/ur_level_zero.hpp" "ur/adapters/level_zero/ur_level_zero.cpp" "ur/adapters/level_zero/ur_interface_loader.cpp" + "ur/adapters/level_zero/adapter.hpp" "ur/adapters/level_zero/command_buffer.hpp" "ur/adapters/level_zero/common.hpp" "ur/adapters/level_zero/context.hpp" @@ -100,6 +101,7 @@ add_sycl_library("ur_adapter_level_zero" SHARED "ur/adapters/level_zero/queue.hpp" "ur/adapters/level_zero/sampler.hpp" "ur/adapters/level_zero/usm.hpp" + "ur/adapters/level_zero/adapter.cpp" "ur/adapters/level_zero/command_buffer.cpp" "ur/adapters/level_zero/common.cpp" "ur/adapters/level_zero/context.cpp" @@ -135,6 +137,8 @@ if ("cuda" IN_LIST SYCL_ENABLE_PLUGINS) "ur/ur.cpp" "ur/usm_allocator.cpp" "ur/usm_allocator.hpp" + "ur/adapters/cuda/adapter.cpp" + "ur/adapters/cuda/adapter.hpp" "ur/adapters/cuda/command_buffer.cpp" "ur/adapters/cuda/command_buffer.hpp" "ur/adapters/cuda/common.cpp" @@ -186,6 +190,8 @@ if ("hip" IN_LIST SYCL_ENABLE_PLUGINS) "ur/ur.cpp" "ur/usm_allocator.cpp" "ur/usm_allocator.hpp" + "ur/adapters/hip/adapter.cpp" + "ur/adapters/hip/adapter.hpp" "ur/adapters/hip/command_buffer.cpp" "ur/adapters/hip/command_buffer.hpp" "ur/adapters/hip/common.cpp" @@ -243,6 +249,7 @@ if("native_cpu" IN_LIST SYCL_ENABLE_PLUGINS) SOURCES "ur/ur.cpp" "ur/ur.hpp" + "ur/adapters/native_cpu/adapter.cpp" "ur/adapters/native_cpu/common.cpp" "ur/adapters/native_cpu/common.hpp" "ur/adapters/native_cpu/context.cpp" @@ -262,7 +269,6 @@ if("native_cpu" IN_LIST SYCL_ENABLE_PLUGINS) "ur/adapters/native_cpu/queue.cpp" "ur/adapters/native_cpu/queue.hpp" "ur/adapters/native_cpu/sampler.cpp" - "ur/adapters/native_cpu/runtime.cpp" "ur/adapters/native_cpu/ur_interface_loader.cpp" "ur/adapters/native_cpu/usm.cpp" "ur/adapters/native_cpu/usm_p2p.cpp" diff --git a/sycl/plugins/unified_runtime/pi2ur.hpp b/sycl/plugins/unified_runtime/pi2ur.hpp index 
f0d6846792a92..e9f87f0a77fc8 100644 --- a/sycl/plugins/unified_runtime/pi2ur.hpp +++ b/sycl/plugins/unified_runtime/pi2ur.hpp @@ -719,6 +719,22 @@ namespace pi2ur { inline pi_result piTearDown(void *PluginParameter) { std::ignore = PluginParameter; + // Fetch the single known adapter (the one which is statically linked) so we + // can release it. Fetching it for a second time (after piPlatformsGet) + // increases the reference count, so we need to release it twice. + // pi_unified_runtime has its own implementation of piTearDown. + static std::once_flag AdapterReleaseFlag; + ur_adapter_handle_t Adapter; + ur_result_t Ret = UR_RESULT_SUCCESS; + std::call_once(AdapterReleaseFlag, [&]() { + Ret = urAdapterGet(1, &Adapter, nullptr); + if (Ret == UR_RESULT_SUCCESS) { + Ret = urAdapterRelease(Adapter); + Ret = urAdapterRelease(Adapter); + } + }); + HANDLE_ERRORS(Ret); + // TODO: Dont check for errors in urTearDown, since // when using Level Zero plugin, the second urTearDown // will fail as ur_loader.so has already been unloaded, @@ -731,9 +747,20 @@ inline pi_result piTearDown(void *PluginParameter) { inline pi_result piPlatformsGet(pi_uint32 NumEntries, pi_platform *Platforms, pi_uint32 *NumPlatforms) { - urInit(0); + urInit(0, nullptr); + // We're not going through the UR loader so we're guaranteed to have exactly + // one adapter (whichever is statically linked). The PI plugin for UR has its + // own implementation of piPlatformsGet. 
+ static ur_adapter_handle_t Adapter; + static std::once_flag AdapterGetFlag; + ur_result_t Ret = UR_RESULT_SUCCESS; + std::call_once(AdapterGetFlag, + [&Ret]() { Ret = urAdapterGet(1, &Adapter, nullptr); }); + HANDLE_ERRORS(Ret); + auto phPlatforms = reinterpret_cast(Platforms); - HANDLE_ERRORS(urPlatformGet(NumEntries, phPlatforms, NumPlatforms)); + HANDLE_ERRORS( + urPlatformGet(&Adapter, 1, NumEntries, phPlatforms, NumPlatforms)); return PI_SUCCESS; } @@ -894,8 +921,18 @@ inline pi_result piDeviceRelease(pi_device Device) { return PI_SUCCESS; } -inline pi_result piPluginGetLastError(char **message) { - std::ignore = message; +inline pi_result piPluginGetLastError(char **Message) { + // We're not going through the UR loader so we're guaranteed to have exactly + // one adapter (whichever is statically linked). The PI plugin for UR has its + // own implementation of piPluginGetLastError. Materialize the adapter + // reference for the urAdapterGetLastError call, then release it. + ur_adapter_handle_t Adapter; + urAdapterGet(1, &Adapter, nullptr); + int32_t ErrorCode; + urAdapterGetLastError(Adapter, const_cast(Message), + &ErrorCode); + urAdapterRelease(Adapter); + return PI_SUCCESS; } diff --git a/sycl/plugins/unified_runtime/pi_unified_runtime.cpp b/sycl/plugins/unified_runtime/pi_unified_runtime.cpp index aeb14bdfa9e7d..8df43edb111a1 100644 --- a/sycl/plugins/unified_runtime/pi_unified_runtime.cpp +++ b/sycl/plugins/unified_runtime/pi_unified_runtime.cpp @@ -17,12 +17,34 @@ static void DieUnsupported() { die("Unified Runtime: functionality is not supported"); } +// Adapters may be released by piTearDown being called, or the global dtors +// being called first. Handle releasing the adapters exactly once. 
+static void releaseAdapters(std::vector &Vec) { + static std::once_flag ReleaseFlag{}; + std::call_once(ReleaseFlag, [&]() { + for (auto Adapter : Vec) { + urAdapterRelease(Adapter); + } + urTearDown(nullptr); + }); +} + +struct AdapterHolder { + ~AdapterHolder() { releaseAdapters(Vec); } + std::vector Vec{}; +} Adapters; + // All PI API interfaces are C interfaces extern "C" { __SYCL_EXPORT pi_result piPlatformsGet(pi_uint32 NumEntries, pi_platform *Platforms, pi_uint32 *NumPlatforms) { - return pi2ur::piPlatformsGet(NumEntries, Platforms, NumPlatforms); + // Get all the platforms from all available adapters + urPlatformGet(Adapters.Vec.data(), static_cast(Adapters.Vec.size()), + NumEntries, reinterpret_cast(Platforms), + NumPlatforms); + + return PI_SUCCESS; } __SYCL_EXPORT pi_result piPlatformGetInfo(pi_platform Platform, @@ -1122,6 +1144,12 @@ __SYCL_EXPORT pi_result piextPeerAccessGetInfo( ParamValueSizeRet); } +__SYCL_EXPORT pi_result piTearDown(void *) { + // releaseAdapters() already calls urTearDown exactly once; don't repeat it. + releaseAdapters(Adapters.Vec); + return PI_SUCCESS; +} + __SYCL_EXPORT pi_result piextMemImageAllocate(pi_context Context, pi_device Device, pi_image_format *ImageFormat, @@ -1256,11 +1284,6 @@ __SYCL_EXPORT pi_result piextSignalExternalSemaphore( Queue, SemHandle, NumEventsInWaitList, EventWaitList, Event); } -// This interface is not in Unified Runtime currently -__SYCL_EXPORT pi_result piTearDown(void *PluginParameter) { - return pi2ur::piTearDown(PluginParameter); -} - // This interface is not in Unified Runtime currently __SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) { PI_ASSERT(PluginInit, PI_ERROR_INVALID_VALUE); @@ -1279,6 +1302,15 @@ __SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) { strncpy(PluginInit->PluginVersion, SupportedVersion, PluginVersionSize); + // Initialize UR and discover adapters + HANDLE_ERRORS(urInit(0, nullptr)); + uint32_t NumAdapters; + HANDLE_ERRORS(urAdapterGet(0, nullptr, &NumAdapters)); + if (NumAdapters > 0) { + 
Adapters.Vec.resize(NumAdapters); + HANDLE_ERRORS(urAdapterGet(NumAdapters, Adapters.Vec.data(), nullptr)); + } + // Bind interfaces that are already supported and "die" for unsupported ones #define _PI_API(api) \ (PluginInit->PiFunctionTable).api = (decltype(&::api))(&DieUnsupported); diff --git a/sycl/plugins/unified_runtime/ur/adapters/cuda/adapter.cpp b/sycl/plugins/unified_runtime/ur/adapters/cuda/adapter.cpp new file mode 100644 index 0000000000000..402f542108b90 --- /dev/null +++ b/sycl/plugins/unified_runtime/ur/adapters/cuda/adapter.cpp @@ -0,0 +1,89 @@ +//===--------- adapter.cpp - CUDA Adapter ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------------------===// + +#include + +#include "common.hpp" + +void enableCUDATracing(); +void disableCUDATracing(); + +struct ur_adapter_handle_t_ { + std::atomic RefCount = 0; + std::mutex Mutex; +}; + +ur_adapter_handle_t_ adapter{}; + +UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t, + ur_loader_config_handle_t) { + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) { + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL +urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, + uint32_t *pNumAdapters) { + if (NumEntries > 0 && phAdapters) { + std::lock_guard Lock{adapter.Mutex}; + if (adapter.RefCount++ == 0) { + enableCUDATracing(); + } + + *phAdapters = &adapter; + } + + if (pNumAdapters) { + *pNumAdapters = 1; + } + + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { + adapter.RefCount++; + + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { + std::lock_guard 
Lock{adapter.Mutex}; + if (--adapter.RefCount == 0) { + disableCUDATracing(); + } + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError( + ur_adapter_handle_t, const char **ppMessage, int32_t *pError) { + *ppMessage = ErrorMessage; + *pError = ErrorMessageCode; + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, + ur_adapter_info_t propName, + size_t propSize, + void *pPropValue, + size_t *pPropSizeRet) { + UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); + + switch (propName) { + case UR_ADAPTER_INFO_BACKEND: + return ReturnValue(UR_ADAPTER_BACKEND_CUDA); + case UR_ADAPTER_INFO_REFERENCE_COUNT: + return ReturnValue(adapter.RefCount.load()); + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + return UR_RESULT_SUCCESS; +} diff --git a/sycl/plugins/unified_runtime/ur/adapters/cuda/adapter.hpp b/sycl/plugins/unified_runtime/ur/adapters/cuda/adapter.hpp new file mode 100644 index 0000000000000..8c9814d600ddf --- /dev/null +++ b/sycl/plugins/unified_runtime/ur/adapters/cuda/adapter.hpp @@ -0,0 +1,11 @@ +//===--------- adapter.hpp - CUDA Adapter ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------------------===// + +struct ur_adapter_handle_t_; + +extern ur_adapter_handle_t_ adapter; diff --git a/sycl/plugins/unified_runtime/ur/adapters/cuda/common.cpp b/sycl/plugins/unified_runtime/ur/adapters/cuda/common.cpp index 756b6ae52e4a3..0cafc390078ae 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/cuda/common.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/cuda/common.cpp @@ -134,10 +134,3 @@ void setPluginSpecificMessage(CUresult cu_res) { setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC); free(message); } - -// Returns plugin specific error and warning messages; common implementation -// that can be shared between adapters -ur_result_t urGetLastResult(ur_platform_handle_t, const char **ppMessage) { - *ppMessage = &ErrorMessage[0]; - return ErrorMessageCode; -} diff --git a/sycl/plugins/unified_runtime/ur/adapters/cuda/device.cpp b/sycl/plugins/unified_runtime/ur/adapters/cuda/device.cpp index 20e7c4c346240..b8350d4853342 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/cuda/device.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/cuda/device.cpp @@ -9,6 +9,7 @@ #include #include +#include "adapter.hpp" #include "context.hpp" #include "device.hpp" #include "platform.hpp" @@ -1206,13 +1207,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; - ur_result_t Result = urPlatformGet(0, nullptr, &NumPlatforms); + ur_adapter_handle_t AdapterHandle = &adapter; + ur_result_t Result = + urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) return Result; ur_platform_handle_t *Plat = static_cast( malloc(NumPlatforms * sizeof(ur_platform_handle_t))); - Result = urPlatformGet(NumPlatforms, Plat, nullptr); + Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, Plat, nullptr); if (Result != UR_RESULT_SUCCESS) return 
Result; diff --git a/sycl/plugins/unified_runtime/ur/adapters/cuda/platform.cpp b/sycl/plugins/unified_runtime/ur/adapters/cuda/platform.cpp index 410b0436d5233..a783f83558bcd 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/cuda/platform.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/cuda/platform.cpp @@ -15,9 +15,6 @@ #include #include -void enableCUDATracing(); -void disableCUDATracing(); - UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( ur_platform_handle_t hPlatform, ur_platform_info_t PlatformInfoType, size_t Size, void *pPlatformInfo, size_t *pSizeRet) { @@ -57,8 +54,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( /// However because multiple devices in a context is not currently supported, /// place each device in a separate platform. UR_APIEXPORT ur_result_t UR_APICALL -urPlatformGet(uint32_t NumEntries, ur_platform_handle_t *phPlatforms, - uint32_t *pNumPlatforms) { +urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, + ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { try { static std::once_flag InitFlag; @@ -188,16 +185,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } -UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t) { - enableCUDATracing(); - return UR_RESULT_SUCCESS; -} - -UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) { - disableCUDATracing(); - return UR_RESULT_SUCCESS; -} - // Get CUDA plugin specific backend option. // Current support is only for optimization options. // Return empty string for cuda. 
diff --git a/sycl/plugins/unified_runtime/ur/adapters/cuda/ur_interface_loader.cpp b/sycl/plugins/unified_runtime/ur/adapters/cuda/ur_interface_loader.cpp index 119bde5955f5c..f1e7b834e4633 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/cuda/ur_interface_loader.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/cuda/ur_interface_loader.cpp @@ -202,6 +202,12 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetGlobalProcAddrTable( } pDdiTable->pfnInit = urInit; pDdiTable->pfnTearDown = urTearDown; + pDdiTable->pfnAdapterGet = urAdapterGet; + pDdiTable->pfnAdapterRelease = urAdapterRelease; + pDdiTable->pfnAdapterRetain = urAdapterRetain; + pDdiTable->pfnAdapterGetLastError = urAdapterGetLastError; + pDdiTable->pfnAdapterGetInfo = urAdapterGetInfo; + return UR_RESULT_SUCCESS; } diff --git a/sycl/plugins/unified_runtime/ur/adapters/cuda/usm.cpp b/sycl/plugins/unified_runtime/ur/adapters/cuda/usm.cpp index 48b7166ef2979..167a6bca22b03 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/cuda/usm.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/cuda/usm.cpp @@ -8,6 +8,7 @@ #include +#include "adapter.hpp" #include "common.hpp" #include "context.hpp" #include "device.hpp" @@ -204,7 +205,9 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // the same index std::vector Platforms; Platforms.resize(DeviceIndex + 1); - Result = urPlatformGet(DeviceIndex + 1, Platforms.data(), nullptr); + ur_adapter_handle_t AdapterHandle = &adapter; + Result = urPlatformGet(&AdapterHandle, 1, DeviceIndex + 1, + Platforms.data(), nullptr); // get the device from the platform ur_device_handle_t Device = Platforms[DeviceIndex]->Devices[0].get(); diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/adapter.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/adapter.cpp new file mode 100644 index 0000000000000..54deaaec9bf95 --- /dev/null +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/adapter.cpp @@ -0,0 +1,78 @@ +//===--------- adapter.cpp - HIP Adapter 
-----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------------------===// + +#include "adapter.hpp" +#include "common.hpp" + +#include +#include + +struct ur_adapter_handle_t_ { + std::atomic RefCount = 0; +}; + +ur_adapter_handle_t_ adapter{}; + +UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t, + ur_loader_config_handle_t) { + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) { + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( + uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { + if (phAdapters) { + adapter.RefCount++; + *phAdapters = &adapter; + } + if (pNumAdapters) { + *pNumAdapters = 1; + } + + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { + // No state to clean up so we don't need to check for 0 references + adapter.RefCount--; + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { + adapter.RefCount++; + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError( + ur_adapter_handle_t, const char **ppMessage, int32_t *pError) { + *ppMessage = ErrorMessage; + *pError = ErrorMessageCode; + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, + ur_adapter_info_t propName, + size_t propSize, + void *pPropValue, + size_t *pPropSizeRet) { + UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); + + switch (propName) { + case UR_ADAPTER_INFO_BACKEND: + return ReturnValue(UR_ADAPTER_BACKEND_HIP); + case UR_ADAPTER_INFO_REFERENCE_COUNT: + return ReturnValue(adapter.RefCount.load()); + default: + return 
UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + return UR_RESULT_SUCCESS; +} diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/adapter.hpp b/sycl/plugins/unified_runtime/ur/adapters/hip/adapter.hpp new file mode 100644 index 0000000000000..84753b79f6b50 --- /dev/null +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/adapter.hpp @@ -0,0 +1,11 @@ +//===--------- adapter.hpp - HIP Adapter -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------------------===// + +struct ur_adapter_handle_t_; + +extern ur_adapter_handle_t_ adapter; diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/platform.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/platform.cpp index 22a55505b1c86..fe773bff8794c 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/platform.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/platform.cpp @@ -48,8 +48,8 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName, /// However because multiple devices in a context is not currently supported, /// place each device in a separate platform. 
UR_APIEXPORT ur_result_t UR_APICALL -urPlatformGet(uint32_t NumEntries, ur_platform_handle_t *phPlatforms, - uint32_t *pNumPlatforms) { +urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, + ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { try { static std::once_flag InitFlag; @@ -129,14 +129,6 @@ urPlatformGetApiVersion(ur_platform_handle_t, ur_api_version_t *pVersion) { return UR_RESULT_SUCCESS; } -UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t) { - return UR_RESULT_SUCCESS; -} - -UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) { - return UR_RESULT_SUCCESS; -} - UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetNativeHandle( ur_platform_handle_t hPlatform, ur_native_handle_t *phNativePlatform) { std::ignore = hPlatform; diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/ur_interface_loader.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/ur_interface_loader.cpp index 580b9916fb485..3814b6459a2f4 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/ur_interface_loader.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/ur_interface_loader.cpp @@ -203,6 +203,12 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetGlobalProcAddrTable( pDdiTable->pfnInit = urInit; pDdiTable->pfnTearDown = urTearDown; + pDdiTable->pfnAdapterGet = urAdapterGet; + pDdiTable->pfnAdapterGetInfo = urAdapterGetInfo; + pDdiTable->pfnAdapterGetLastError = urAdapterGetLastError; + pDdiTable->pfnAdapterRelease = urAdapterRelease; + pDdiTable->pfnAdapterRetain = urAdapterRetain; + return UR_RESULT_SUCCESS; } diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp index 03a4ff18d7f5b..296954268a818 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp @@ -8,6 +8,7 @@ #include +#include "adapter.hpp" #include "common.hpp" #include "context.hpp" #include "device.hpp" @@ -174,7 +175,9 @@ 
urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // the same index std::vector Platforms; Platforms.resize(DeviceIdx + 1); - Result = urPlatformGet(DeviceIdx + 1, Platforms.data(), nullptr); + ur_adapter_handle_t AdapterHandle = &adapter; + Result = urPlatformGet(&AdapterHandle, 1, DeviceIdx + 1, Platforms.data(), + nullptr); // get the device from the platform ur_device_handle_t Device = Platforms[DeviceIdx]->Devices[0].get(); diff --git a/sycl/plugins/unified_runtime/ur/adapters/level_zero/adapter.cpp b/sycl/plugins/unified_runtime/ur/adapters/level_zero/adapter.cpp new file mode 100644 index 0000000000000..453b8b5ee50ab --- /dev/null +++ b/sycl/plugins/unified_runtime/ur/adapters/level_zero/adapter.cpp @@ -0,0 +1,206 @@ +//===--------- adapter.cpp - Level Zero Adapter ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------------------===// + +#include "adapter.hpp" +#include "ur_level_zero.hpp" + +ur_adapter_handle_t_ Adapter{}; + +UR_APIEXPORT ur_result_t UR_APICALL +urInit(ur_device_init_flags_t + DeviceFlags, ///< [in] device initialization flags. + ///< must be 0 (default) or a combination of + ///< ::ur_device_init_flag_t. + ur_loader_config_handle_t) { + std::ignore = DeviceFlags; + + return UR_RESULT_SUCCESS; +} + +ur_result_t adapterStateTeardown() { + // reclaim ur_platform_handle_t objects here since we don't have + // urPlatformRelease. + for (ur_platform_handle_t Platform : *URPlatformsCache) { + delete Platform; + } + delete URPlatformsCache; + delete URPlatformsCacheMutex; + + bool LeakFound = false; + + // Print the balance of various create/destroy native calls. + // The idea is to verify if the number of create(+) and destroy(-) calls are + // matched. 
+ if (ZeCallCount && (UrL0Debug & UR_L0_DEBUG_CALL_COUNT) != 0) { + // clang-format off + // + // The format of this table is such that each row accounts for a + // specific type of objects, and all elements in the raw except the last + // one are allocating objects of that type, while the last element is known + // to deallocate objects of that type. + // + std::vector> CreateDestroySet = { + {"zeContextCreate", "zeContextDestroy"}, + {"zeCommandQueueCreate", "zeCommandQueueDestroy"}, + {"zeModuleCreate", "zeModuleDestroy"}, + {"zeKernelCreate", "zeKernelDestroy"}, + {"zeEventPoolCreate", "zeEventPoolDestroy"}, + {"zeCommandListCreateImmediate", "zeCommandListCreate", "zeCommandListDestroy"}, + {"zeEventCreate", "zeEventDestroy"}, + {"zeFenceCreate", "zeFenceDestroy"}, + {"zeImageCreate", "zeImageDestroy"}, + {"zeSamplerCreate", "zeSamplerDestroy"}, + {"zeMemAllocDevice", "zeMemAllocHost", "zeMemAllocShared", "zeMemFree"}, + }; + + // A sample output aimed below is this: + // ------------------------------------------------------------------------ + // zeContextCreate = 1 \---> zeContextDestroy = 1 + // zeCommandQueueCreate = 1 \---> zeCommandQueueDestroy = 1 + // zeModuleCreate = 1 \---> zeModuleDestroy = 1 + // zeKernelCreate = 1 \---> zeKernelDestroy = 1 + // zeEventPoolCreate = 1 \---> zeEventPoolDestroy = 1 + // zeCommandListCreateImmediate = 1 | + // zeCommandListCreate = 1 \---> zeCommandListDestroy = 1 ---> LEAK = 1 + // zeEventCreate = 2 \---> zeEventDestroy = 2 + // zeFenceCreate = 1 \---> zeFenceDestroy = 1 + // zeImageCreate = 0 \---> zeImageDestroy = 0 + // zeSamplerCreate = 0 \---> zeSamplerDestroy = 0 + // zeMemAllocDevice = 0 | + // zeMemAllocHost = 1 | + // zeMemAllocShared = 0 \---> zeMemFree = 1 + // + // clang-format on + + fprintf(stderr, "ZE_DEBUG=%d: check balance of create/destroy calls\n", + UR_L0_DEBUG_CALL_COUNT); + fprintf(stderr, + "----------------------------------------------------------\n"); + for (const auto &Row : 
CreateDestroySet) { + int diff = 0; + for (auto I = Row.begin(); I != Row.end();) { + const char *ZeName = (*I).c_str(); + const auto &ZeCount = (*ZeCallCount)[*I]; + + bool First = (I == Row.begin()); + bool Last = (++I == Row.end()); + + if (Last) { + fprintf(stderr, " \\--->"); + diff -= ZeCount; + } else { + diff += ZeCount; + if (!First) { + fprintf(stderr, " | \n"); + } + } + + fprintf(stderr, "%30s = %-5d", ZeName, ZeCount); + } + + if (diff) { + LeakFound = true; + fprintf(stderr, " ---> LEAK = %d", diff); + } + fprintf(stderr, "\n"); + } + + ZeCallCount->clear(); + delete ZeCallCount; + ZeCallCount = nullptr; + } + if (LeakFound) + return UR_RESULT_ERROR_INVALID_MEM_OBJECT; + + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urTearDown( + void *Params ///< [in] pointer to tear down parameters +) { + std::ignore = Params; + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( + uint32_t NumEntries, ///< [in] the number of adapters to be added to + ///< phAdapters. If phAdapters is not NULL, then + ///< NumEntries should be greater than zero, otherwise + ///< ::UR_RESULT_ERROR_INVALID_SIZE, will be returned. + ur_adapter_handle_t + *Adapters, ///< [out][optional][range(0, NumEntries)] array of handle of + ///< adapters. If NumEntries is less than the number of + ///< adapters available, then + ///< ::urAdapterGet shall only retrieve that number of + ///< adapters. + uint32_t *NumAdapters ///< [out][optional] returns the total number of + ///< adapters available. 
+) { + if (NumEntries > 0 && Adapters) { + std::lock_guard Lock{Adapter.Mutex}; + // TODO: Some initialization that happens in urPlatformGet could be moved + // here for when RefCount reaches 1 + Adapter.RefCount++; + *Adapters = &Adapter; + } + + if (NumAdapters) { + *NumAdapters = 1; + } + + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { + std::lock_guard Lock{Adapter.Mutex}; + if (--Adapter.RefCount == 0) { + adapterStateTeardown(); + } + + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { + std::lock_guard Lock{Adapter.Mutex}; + Adapter.RefCount++; + + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError( + ur_adapter_handle_t Adapter, ///< [in] handle of the adapter instance + const char **Message, ///< [out] pointer to a C string where the adapter + ///< specific error message will be stored. + int32_t *Error ///< [out] pointer to an integer where the adapter specific + ///< error code will be stored. 
+) { + std::ignore = Adapter; + std::ignore = Message; + std::ignore = Error; + urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__); + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, + ur_adapter_info_t PropName, + size_t PropSize, + void *PropValue, + size_t *PropSizeRet) { + UrReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); + + switch (PropName) { + case UR_ADAPTER_INFO_BACKEND: + return ReturnValue(UR_ADAPTER_BACKEND_LEVEL_ZERO); + case UR_ADAPTER_INFO_REFERENCE_COUNT: + return ReturnValue(Adapter.RefCount.load()); + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + return UR_RESULT_SUCCESS; +} diff --git a/sycl/plugins/unified_runtime/ur/adapters/level_zero/adapter.hpp b/sycl/plugins/unified_runtime/ur/adapters/level_zero/adapter.hpp new file mode 100644 index 0000000000000..1b04a8555efb5 --- /dev/null +++ b/sycl/plugins/unified_runtime/ur/adapters/level_zero/adapter.hpp @@ -0,0 +1,17 @@ +//===--------- adapter.hpp - Level Zero Adapter ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------------------===// + +#include +#include + +struct ur_adapter_handle_t_ { + std::atomic RefCount = 0; + std::mutex Mutex; +}; + +extern ur_adapter_handle_t_ Adapter; diff --git a/sycl/plugins/unified_runtime/ur/adapters/level_zero/platform.cpp b/sycl/plugins/unified_runtime/ur/adapters/level_zero/platform.cpp index 2011e2f5c4e2d..86dbeceea5817 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/level_zero/platform.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/level_zero/platform.cpp @@ -7,121 +7,11 @@ //===-----------------------------------------------------------------===// #include "platform.hpp" +#include "adapter.hpp" #include "ur_level_zero.hpp" -UR_APIEXPORT ur_result_t UR_APICALL urInit( - ur_device_init_flags_t - DeviceFlags ///< [in] device initialization flags. - ///< must be 0 (default) or a combination of - ///< ::ur_device_init_flag_t. -) { - std::ignore = DeviceFlags; - - return UR_RESULT_SUCCESS; -} - -UR_APIEXPORT ur_result_t UR_APICALL urTearDown( - void *Params ///< [in] pointer to tear down parameters -) { - std::ignore = Params; - // reclaim ur_platform_handle_t objects here since we don't have - // urPlatformRelease. - for (ur_platform_handle_t Platform : *URPlatformsCache) { - delete Platform; - } - delete URPlatformsCache; - delete URPlatformsCacheMutex; - - bool LeakFound = false; - - // Print the balance of various create/destroy native calls. - // The idea is to verify if the number of create(+) and destroy(-) calls are - // matched. - if (ZeCallCount && (UrL0Debug & UR_L0_DEBUG_CALL_COUNT) != 0) { - // clang-format off - // - // The format of this table is such that each row accounts for a - // specific type of objects, and all elements in the raw except the last - // one are allocating objects of that type, while the last element is known - // to deallocate objects of that type. 
- // - std::vector> CreateDestroySet = { - {"zeContextCreate", "zeContextDestroy"}, - {"zeCommandQueueCreate", "zeCommandQueueDestroy"}, - {"zeModuleCreate", "zeModuleDestroy"}, - {"zeKernelCreate", "zeKernelDestroy"}, - {"zeEventPoolCreate", "zeEventPoolDestroy"}, - {"zeCommandListCreateImmediate", "zeCommandListCreate", "zeCommandListDestroy"}, - {"zeEventCreate", "zeEventDestroy"}, - {"zeFenceCreate", "zeFenceDestroy"}, - {"zeImageCreate", "zeImageDestroy"}, - {"zeSamplerCreate", "zeSamplerDestroy"}, - {"zeMemAllocDevice", "zeMemAllocHost", "zeMemAllocShared", "zeMemFree"}, - }; - - // A sample output aimed below is this: - // ------------------------------------------------------------------------ - // zeContextCreate = 1 \---> zeContextDestroy = 1 - // zeCommandQueueCreate = 1 \---> zeCommandQueueDestroy = 1 - // zeModuleCreate = 1 \---> zeModuleDestroy = 1 - // zeKernelCreate = 1 \---> zeKernelDestroy = 1 - // zeEventPoolCreate = 1 \---> zeEventPoolDestroy = 1 - // zeCommandListCreateImmediate = 1 | - // zeCommandListCreate = 1 \---> zeCommandListDestroy = 1 ---> LEAK = 1 - // zeEventCreate = 2 \---> zeEventDestroy = 2 - // zeFenceCreate = 1 \---> zeFenceDestroy = 1 - // zeImageCreate = 0 \---> zeImageDestroy = 0 - // zeSamplerCreate = 0 \---> zeSamplerDestroy = 0 - // zeMemAllocDevice = 0 | - // zeMemAllocHost = 1 | - // zeMemAllocShared = 0 \---> zeMemFree = 1 - // - // clang-format on - - fprintf(stderr, "ZE_DEBUG=%d: check balance of create/destroy calls\n", - UR_L0_DEBUG_CALL_COUNT); - fprintf(stderr, - "----------------------------------------------------------\n"); - for (const auto &Row : CreateDestroySet) { - int diff = 0; - for (auto I = Row.begin(); I != Row.end();) { - const char *ZeName = (*I).c_str(); - const auto &ZeCount = (*ZeCallCount)[*I]; - - bool First = (I == Row.begin()); - bool Last = (++I == Row.end()); - - if (Last) { - fprintf(stderr, " \\--->"); - diff -= ZeCount; - } else { - diff += ZeCount; - if (!First) { - fprintf(stderr, " | 
\n"); - } - } - - fprintf(stderr, "%30s = %-5d", ZeName, ZeCount); - } - - if (diff) { - LeakFound = true; - fprintf(stderr, " ---> LEAK = %d", diff); - } - fprintf(stderr, "\n"); - } - - ZeCallCount->clear(); - delete ZeCallCount; - ZeCallCount = nullptr; - } - if (LeakFound) - return UR_RESULT_ERROR_INVALID_MEM_OBJECT; - - return UR_RESULT_SUCCESS; -} - UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet( + ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, ///< [in] the number of platforms to be added to ///< phPlatforms. If phPlatforms is not NULL, then ///< NumEntries should be greater than zero, otherwise @@ -171,7 +61,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet( // Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms. if (ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) { - UR_ASSERT(NumEntries != 0, UR_RESULT_ERROR_INVALID_VALUE); + UR_ASSERT(NumEntries == 0, UR_RESULT_ERROR_INVALID_VALUE); if (NumPlatforms) *NumPlatforms = 0; return UR_RESULT_SUCCESS; @@ -322,11 +212,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( auto ZeDriver = ur_cast(NativePlatform); uint32_t NumPlatforms = 0; - UR_CALL(urPlatformGet(0, nullptr, &NumPlatforms)); + ur_adapter_handle_t AdapterHandle = &Adapter; + UR_CALL(urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms)); if (NumPlatforms) { std::vector Platforms(NumPlatforms); - UR_CALL(urPlatformGet(NumPlatforms, Platforms.data(), nullptr)); + UR_CALL(urPlatformGet(&AdapterHandle, 1, NumPlatforms, Platforms.data(), + nullptr)); // The SYCL spec requires that the set of platforms must remain fixed for // the duration of the application's execution. 
We assume that we found all @@ -344,20 +236,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( return UR_RESULT_ERROR_INVALID_VALUE; } -UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetLastError( - ur_platform_handle_t Platform, ///< [in] handle of the platform instance - const char **Message, ///< [out] pointer to a C string where the adapter - ///< specific error message will be stored. - int32_t *Error ///< [out] pointer to an integer where the adapter specific - ///< error code will be stored. -) { - std::ignore = Platform; - std::ignore = Message; - std::ignore = Error; - urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; -} - ur_result_t ur_platform_handle_t_::initialize() { // Cache driver properties ZeStruct ZeDriverProperties; diff --git a/sycl/plugins/unified_runtime/ur/adapters/level_zero/queue.cpp b/sycl/plugins/unified_runtime/ur/adapters/level_zero/queue.cpp index 3e5f6b607e22b..4b25609197742 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/level_zero/queue.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/level_zero/queue.cpp @@ -11,6 +11,7 @@ #include #include +#include "adapter.hpp" #include "common.hpp" #include "queue.hpp" #include "ur_level_zero.hpp" @@ -568,7 +569,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle( // Maybe this is not completely correct. 
uint32_t NumEntries = 1; ur_platform_handle_t Platform{}; - UR_CALL(urPlatformGet(NumEntries, &Platform, nullptr)); + ur_adapter_handle_t AdapterHandle = &Adapter; + UR_CALL(urPlatformGet(&AdapterHandle, 1, NumEntries, &Platform, nullptr)); ur_device_handle_t UrDevice = Device; if (UrDevice == nullptr) { diff --git a/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_interface_loader.cpp b/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_interface_loader.cpp index 9c330b5b20bfb..ebb87114567b6 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_interface_loader.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_interface_loader.cpp @@ -33,6 +33,11 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetGlobalProcAddrTable( pDdiTable->pfnInit = urInit; pDdiTable->pfnTearDown = urTearDown; + pDdiTable->pfnAdapterGet = urAdapterGet; + pDdiTable->pfnAdapterRelease = urAdapterRelease; + pDdiTable->pfnAdapterRetain = urAdapterRetain; + pDdiTable->pfnAdapterGetLastError = urAdapterGetLastError; + pDdiTable->pfnAdapterGetInfo = urAdapterGetInfo; return retVal; } @@ -182,7 +187,6 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetPlatformProcAddrTable( pDdiTable->pfnCreateWithNativeHandle = urPlatformCreateWithNativeHandle; pDdiTable->pfnGetApiVersion = urPlatformGetApiVersion; pDdiTable->pfnGetBackendOption = urPlatformGetBackendOption; - pDdiTable->pfnGetLastError = urPlatformGetLastError; return retVal; } diff --git a/sycl/plugins/unified_runtime/ur/adapters/native_cpu/adapter.cpp b/sycl/plugins/unified_runtime/ur/adapters/native_cpu/adapter.cpp new file mode 100644 index 0000000000000..90d74093ad911 --- /dev/null +++ b/sycl/plugins/unified_runtime/ur/adapters/native_cpu/adapter.cpp @@ -0,0 +1,64 @@ +//===---------------- runtime.cpp - Native CPU Adapter --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "common.hpp" +#include "ur_api.h" + +struct ur_adapter_handle_t_ { + std::atomic RefCount = 0; +} Adapter; + +UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t, + ur_loader_config_handle_t) { + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) { + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( + uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { + if (phAdapters) { + Adapter.RefCount++; + *phAdapters = &Adapter; + } + if (pNumAdapters) { + *pNumAdapters = 1; + } + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { + Adapter.RefCount--; + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { + Adapter.RefCount++; + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, + ur_adapter_info_t propName, + size_t propSize, + void *pPropValue, + size_t *pPropSizeRet) { + UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); + + switch (propName) { + case UR_ADAPTER_INFO_BACKEND: + return ReturnValue(UR_ADAPTER_BACKEND_NATIVE_CPU); + case UR_ADAPTER_INFO_REFERENCE_COUNT: + return ReturnValue(Adapter.RefCount.load()); + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + return UR_RESULT_SUCCESS; +} diff --git a/sycl/plugins/unified_runtime/ur/adapters/native_cpu/runtime.cpp b/sycl/plugins/unified_runtime/ur/adapters/native_cpu/runtime.cpp deleted file mode 100644 index 62d7caf6c8563..0000000000000 --- a/sycl/plugins/unified_runtime/ur/adapters/native_cpu/runtime.cpp +++ /dev/null @@ -1,17 +0,0 @@ -//===---------------- runtime.cpp - Native CPU Adapter --------------------===// -// -// Part of the LLVM Project, under the Apache License 
v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "ur_api.h" - -UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t) { - return UR_RESULT_SUCCESS; -} - -UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) { - return UR_RESULT_SUCCESS; -} diff --git a/sycl/plugins/unified_runtime/ur/adapters/native_cpu/ur_interface_loader.cpp b/sycl/plugins/unified_runtime/ur/adapters/native_cpu/ur_interface_loader.cpp index a7c0cca576167..abc52daead372 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/native_cpu/ur_interface_loader.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/native_cpu/ur_interface_loader.cpp @@ -200,6 +200,10 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetGlobalProcAddrTable( } pDdiTable->pfnInit = urInit; pDdiTable->pfnTearDown = urTearDown; + pDdiTable->pfnAdapterGet = urAdapterGet; + pDdiTable->pfnAdapterGetInfo = urAdapterGetInfo; + pDdiTable->pfnAdapterRelease = urAdapterRelease; + pDdiTable->pfnAdapterRetain = urAdapterRetain; return UR_RESULT_SUCCESS; }