Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sycl/plugins/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ add_sycl_plugin(cuda
"../unified_runtime/ur/ur.cpp"
"../unified_runtime/ur/usm_allocator.cpp"
"../unified_runtime/ur/usm_allocator.hpp"
"../unified_runtime/ur/adapters/cuda/adapter.cpp"
"../unified_runtime/ur/adapters/cuda/adapter.hpp"
"../unified_runtime/ur/adapters/cuda/command_buffer.cpp"
"../unified_runtime/ur/adapters/cuda/command_buffer.hpp"
"../unified_runtime/ur/adapters/cuda/common.cpp"
Expand Down
2 changes: 2 additions & 0 deletions sycl/plugins/hip/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ add_sycl_plugin(hip
"../unified_runtime/pi2ur.cpp"
"../unified_runtime/ur/ur.hpp"
"../unified_runtime/ur/ur.cpp"
"../unified_runtime/ur/adapters/hip/adapter.cpp"
"../unified_runtime/ur/adapters/hip/adapter.hpp"
"../unified_runtime/ur/adapters/hip/command_buffer.cpp"
"../unified_runtime/ur/adapters/hip/command_buffer.hpp"
"../unified_runtime/ur/adapters/hip/common.cpp"
Expand Down
2 changes: 2 additions & 0 deletions sycl/plugins/level_zero/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ add_sycl_plugin(level_zero
"../unified_runtime/ur/usm_allocator_config.hpp"
"../unified_runtime/ur/adapters/level_zero/ur_level_zero.hpp"
"../unified_runtime/ur/adapters/level_zero/command_buffer.hpp"
"../unified_runtime/ur/adapters/level_zero/adapter.hpp"
"../unified_runtime/ur/adapters/level_zero/common.hpp"
"../unified_runtime/ur/adapters/level_zero/context.hpp"
"../unified_runtime/ur/adapters/level_zero/device.hpp"
Expand All @@ -116,6 +117,7 @@ add_sycl_plugin(level_zero
"../unified_runtime/ur/adapters/level_zero/sampler.hpp"
"../unified_runtime/ur/adapters/level_zero/usm.hpp"
"../unified_runtime/ur/adapters/level_zero/ur_level_zero.cpp"
"../unified_runtime/ur/adapters/level_zero/adapter.cpp"
"../unified_runtime/ur/adapters/level_zero/command_buffer.cpp"
"../unified_runtime/ur/adapters/level_zero/common.cpp"
"../unified_runtime/ur/adapters/level_zero/context.cpp"
Expand Down
2 changes: 1 addition & 1 deletion sycl/plugins/native_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_sycl_plugin(native_cpu
"../unified_runtime/pi2ur.cpp"
"../unified_runtime/ur/ur.hpp"
"../unified_runtime/ur/ur.cpp"
"../unified_runtime/ur/adapters/native_cpu/adapter.cpp"
"../unified_runtime/ur/adapters/native_cpu/common.cpp"
"../unified_runtime/ur/adapters/native_cpu/common.hpp"
"../unified_runtime/ur/adapters/native_cpu/context.cpp"
Expand All @@ -24,7 +25,6 @@ add_sycl_plugin(native_cpu
"../unified_runtime/ur/adapters/native_cpu/queue.cpp"
"../unified_runtime/ur/adapters/native_cpu/queue.hpp"
"../unified_runtime/ur/adapters/native_cpu/sampler.cpp"
"../unified_runtime/ur/adapters/native_cpu/runtime.cpp"
"../unified_runtime/ur/adapters/native_cpu/ur_interface_loader.cpp"
"../unified_runtime/ur/adapters/native_cpu/usm.cpp"
"../unified_runtime/ur/adapters/native_cpu/usm_p2p.cpp"
Expand Down
10 changes: 8 additions & 2 deletions sycl/plugins/unified_runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ if (NOT DEFINED UNIFIED_RUNTIME_LIBRARY OR NOT DEFINED UNIFIED_RUNTIME_INCLUDE_D
include(FetchContent)

set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git")
set(UNIFIED_RUNTIME_TAG 3c6f02c7a76a0448a83932d93c2dbeff25af70aa)
set(UNIFIED_RUNTIME_TAG 974a7d64dd1a26ede1ff27919b3b8713b848c376)

message(STATUS "Will fetch Unified Runtime from ${UNIFIED_RUNTIME_REPO}")
FetchContent_Declare(unified-runtime
Expand Down Expand Up @@ -86,6 +86,7 @@ add_sycl_library("ur_adapter_level_zero" SHARED
"ur/adapters/level_zero/ur_level_zero.hpp"
"ur/adapters/level_zero/ur_level_zero.cpp"
"ur/adapters/level_zero/ur_interface_loader.cpp"
"ur/adapters/level_zero/adapter.hpp"
"ur/adapters/level_zero/command_buffer.hpp"
"ur/adapters/level_zero/common.hpp"
"ur/adapters/level_zero/context.hpp"
Expand All @@ -100,6 +101,7 @@ add_sycl_library("ur_adapter_level_zero" SHARED
"ur/adapters/level_zero/queue.hpp"
"ur/adapters/level_zero/sampler.hpp"
"ur/adapters/level_zero/usm.hpp"
"ur/adapters/level_zero/adapter.cpp"
"ur/adapters/level_zero/command_buffer.cpp"
"ur/adapters/level_zero/common.cpp"
"ur/adapters/level_zero/context.cpp"
Expand Down Expand Up @@ -135,6 +137,8 @@ if ("cuda" IN_LIST SYCL_ENABLE_PLUGINS)
"ur/ur.cpp"
"ur/usm_allocator.cpp"
"ur/usm_allocator.hpp"
"ur/adapters/cuda/adapter.cpp"
"ur/adapters/cuda/adapter.hpp"
"ur/adapters/cuda/command_buffer.cpp"
"ur/adapters/cuda/command_buffer.hpp"
"ur/adapters/cuda/common.cpp"
Expand Down Expand Up @@ -186,6 +190,8 @@ if ("hip" IN_LIST SYCL_ENABLE_PLUGINS)
"ur/ur.cpp"
"ur/usm_allocator.cpp"
"ur/usm_allocator.hpp"
"ur/adapters/hip/adapter.cpp"
"ur/adapters/hip/adapter.hpp"
"ur/adapters/hip/command_buffer.cpp"
"ur/adapters/hip/command_buffer.hpp"
"ur/adapters/hip/common.cpp"
Expand Down Expand Up @@ -243,6 +249,7 @@ if("native_cpu" IN_LIST SYCL_ENABLE_PLUGINS)
SOURCES
"ur/ur.cpp"
"ur/ur.hpp"
"ur/adapters/native_cpu/adapter.cpp"
"ur/adapters/native_cpu/common.cpp"
"ur/adapters/native_cpu/common.hpp"
"ur/adapters/native_cpu/context.cpp"
Expand All @@ -262,7 +269,6 @@ if("native_cpu" IN_LIST SYCL_ENABLE_PLUGINS)
"ur/adapters/native_cpu/queue.cpp"
"ur/adapters/native_cpu/queue.hpp"
"ur/adapters/native_cpu/sampler.cpp"
"ur/adapters/native_cpu/runtime.cpp"
"ur/adapters/native_cpu/ur_interface_loader.cpp"
"ur/adapters/native_cpu/usm.cpp"
"ur/adapters/native_cpu/usm_p2p.cpp"
Expand Down
45 changes: 41 additions & 4 deletions sycl/plugins/unified_runtime/pi2ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,22 @@ namespace pi2ur {

inline pi_result piTearDown(void *PluginParameter) {
std::ignore = PluginParameter;
// Fetch the single known adapter (the one which is statically linked) so we
// can release it. Fetching it for a second time (after piPlatformsGet)
// increases the reference count, so we need to release it twice.
// pi_unified_runtime has its own implementation of piTearDown.
static std::once_flag AdapterReleaseFlag;
ur_adapter_handle_t Adapter;
ur_result_t Ret = UR_RESULT_SUCCESS;
std::call_once(AdapterReleaseFlag, [&]() {
Ret = urAdapterGet(1, &Adapter, nullptr);
if (Ret == UR_RESULT_SUCCESS) {
Ret = urAdapterRelease(Adapter);
Ret = urAdapterRelease(Adapter);
}
});
HANDLE_ERRORS(Ret);

// TODO: Don't check for errors in urTearDown, since
// when using Level Zero plugin, the second urTearDown
// will fail as ur_loader.so has already been unloaded,
Expand All @@ -731,9 +747,20 @@ inline pi_result piTearDown(void *PluginParameter) {
inline pi_result piPlatformsGet(pi_uint32 NumEntries, pi_platform *Platforms,
pi_uint32 *NumPlatforms) {

urInit(0);
urInit(0, nullptr);
// We're not going through the UR loader so we're guaranteed to have exactly
// one adapter (whichever is statically linked). The PI plugin for UR has its
// own implementation of piPlatformsGet.
static ur_adapter_handle_t Adapter;
static std::once_flag AdapterGetFlag;
ur_result_t Ret = UR_RESULT_SUCCESS;
std::call_once(AdapterGetFlag,
[&Ret]() { Ret = urAdapterGet(1, &Adapter, nullptr); });
HANDLE_ERRORS(Ret);

auto phPlatforms = reinterpret_cast<ur_platform_handle_t *>(Platforms);
HANDLE_ERRORS(urPlatformGet(NumEntries, phPlatforms, NumPlatforms));
HANDLE_ERRORS(
urPlatformGet(&Adapter, 1, NumEntries, phPlatforms, NumPlatforms));
return PI_SUCCESS;
}

Expand Down Expand Up @@ -894,8 +921,18 @@ inline pi_result piDeviceRelease(pi_device Device) {
return PI_SUCCESS;
}

inline pi_result piPluginGetLastError(char **message) {
std::ignore = message;
// Retrieves the last error message recorded by the adapter.
//
// Message: out-parameter that urAdapterGetLastError points at the adapter's
//          message string.
// Returns PI_SUCCESS once the message has been queried, or the PI error that
// HANDLE_ERRORS derives from a failed urAdapterGet.
inline pi_result piPluginGetLastError(char **Message) {
  // We're not going through the UR loader so we're guaranteed to have exactly
  // one adapter (whichever is statically linked). The PI plugin for UR has its
  // own implementation of piPluginGetLastError. Materialize the adapter
  // reference for the urAdapterGetLastError call, then release it.
  //
  // The handle is initialized and the lookup is checked: if urAdapterGet
  // fails, the handle must not be passed on to urAdapterGetLastError or
  // urAdapterRelease (no reference was taken, so nothing needs releasing).
  ur_adapter_handle_t Adapter = nullptr;
  HANDLE_ERRORS(urAdapterGet(1, &Adapter, nullptr));

  // The native error code is only fetched because the UR entry point needs an
  // out-parameter; it is not reported through this PI interface.
  int32_t ErrorCode = 0;
  urAdapterGetLastError(Adapter, const_cast<const char **>(Message),
                        &ErrorCode);
  urAdapterRelease(Adapter);

  return PI_SUCCESS;
}

Expand Down
44 changes: 38 additions & 6 deletions sycl/plugins/unified_runtime/pi_unified_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,34 @@ static void DieUnsupported() {
die("Unified Runtime: functionality is not supported");
}

// Adapters may be released by piTearDown being called, or the global dtors
// being called first. Handle releasing the adapters exactly once.
static void releaseAdapters(std::vector<ur_adapter_handle_t> &Vec) {
static std::once_flag ReleaseFlag{};
std::call_once(ReleaseFlag, [&]() {
for (auto Adapter : Vec) {
urAdapterRelease(Adapter);
}
urTearDown(nullptr);
});
}

// Holds the adapter handles used by this plugin. The destructor hands the
// list to releaseAdapters so the handles are also released when the global
// object `Adapters` is destroyed.
struct AdapterHolder {
  std::vector<ur_adapter_handle_t> Vec{};

  ~AdapterHolder() { releaseAdapters(Vec); }
} Adapters;

// All PI API interfaces are C interfaces
extern "C" {
__SYCL_EXPORT pi_result piPlatformsGet(pi_uint32 NumEntries,
pi_platform *Platforms,
pi_uint32 *NumPlatforms) {
return pi2ur::piPlatformsGet(NumEntries, Platforms, NumPlatforms);
// Get all the platforms from all available adapters
urPlatformGet(Adapters.Vec.data(), static_cast<uint32_t>(Adapters.Vec.size()),
NumEntries, reinterpret_cast<ur_platform_handle_t *>(Platforms),
NumPlatforms);

return PI_SUCCESS;
}

__SYCL_EXPORT pi_result piPlatformGetInfo(pi_platform Platform,
Expand Down Expand Up @@ -1122,6 +1144,12 @@ __SYCL_EXPORT pi_result piextPeerAccessGetInfo(
ParamValueSizeRet);
}

// Plugin teardown entry point; the plugin parameter is unused.
//
// releaseAdapters() releases every adapter in Adapters.Vec and then calls
// urTearDown(nullptr), all under a std::once_flag. urTearDown must therefore
// not be called again here: doing so would tear UR down a second time on the
// first piTearDown call and once more on every later call.
__SYCL_EXPORT pi_result piTearDown(void *) {
  releaseAdapters(Adapters.Vec);
  return PI_SUCCESS;
}

__SYCL_EXPORT pi_result piextMemImageAllocate(pi_context Context,
pi_device Device,
pi_image_format *ImageFormat,
Expand Down Expand Up @@ -1256,11 +1284,6 @@ __SYCL_EXPORT pi_result piextSignalExternalSemaphore(
Queue, SemHandle, NumEventsInWaitList, EventWaitList, Event);
}

// This interface is not in Unified Runtime currently
__SYCL_EXPORT pi_result piTearDown(void *PluginParameter) {
return pi2ur::piTearDown(PluginParameter);
}

// This interface is not in Unified Runtime currently
__SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
PI_ASSERT(PluginInit, PI_ERROR_INVALID_VALUE);
Expand All @@ -1279,6 +1302,15 @@ __SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {

strncpy(PluginInit->PluginVersion, SupportedVersion, PluginVersionSize);

// Initialize UR and discover adapters
HANDLE_ERRORS(urInit(0, nullptr));
uint32_t NumAdapters;
HANDLE_ERRORS(urAdapterGet(0, nullptr, &NumAdapters));
if (NumAdapters > 0) {
Adapters.Vec.resize(NumAdapters);
HANDLE_ERRORS(urAdapterGet(NumAdapters, Adapters.Vec.data(), nullptr));
}

// Bind interfaces that are already supported and "die" for unsupported ones
#define _PI_API(api) \
(PluginInit->PiFunctionTable).api = (decltype(&::api))(&DieUnsupported);
Expand Down
89 changes: 89 additions & 0 deletions sycl/plugins/unified_runtime/ur/adapters/cuda/adapter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//===--------- adapter.cpp - CUDA Adapter ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===-----------------------------------------------------------------===//

#include <ur_api.h>

#include "common.hpp"

void enableCUDATracing();
void disableCUDATracing();

// State behind the adapter handle.
//  - RefCount: reference counter, starts at zero. It is incremented by
//    urAdapterGet / urAdapterRetain and decremented by urAdapterRelease.
//  - Mutex: locked by urAdapterGet and urAdapterRelease around their
//    reference-count update and tracing toggle.
struct ur_adapter_handle_t_ {
  std::atomic<uint32_t> RefCount{0};
  std::mutex Mutex;
};

// The single adapter instance; urAdapterGet hands out its address.
ur_adapter_handle_t_ adapter{};

// Adapter initialization entry point. Both arguments (device-init flags and
// loader config handle) are ignored; the function does no work and always
// returns UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t,
ur_loader_config_handle_t) {
return UR_RESULT_SUCCESS;
}

// Adapter teardown entry point. The parameter is ignored; the function does
// no work and always returns UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urTearDown(void *) {
return UR_RESULT_SUCCESS;
}

// Reports the single adapter provided by this library.
//
// NumEntries / phAdapters: when NumEntries is non-zero and phAdapters is
//   non-null, the first slot receives the address of the global `adapter` and
//   its reference count is incremented. The 0 -> 1 transition turns CUDA
//   tracing on.
// pNumAdapters: when non-null, receives 1.
// Always returns UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
             uint32_t *pNumAdapters) {
  const bool WantsHandle = NumEntries > 0 && phAdapters != nullptr;
  if (WantsHandle) {
    std::lock_guard<std::mutex> Guard{adapter.Mutex};
    // fetch_add returns the value held before the increment.
    const bool IsFirstReference = adapter.RefCount.fetch_add(1) == 0;
    if (IsFirstReference)
      enableCUDATracing();

    phAdapters[0] = &adapter;
  }

  if (pNumAdapters != nullptr)
    pNumAdapters[0] = 1;

  return UR_RESULT_SUCCESS;
}

// Adds one reference to the global adapter. The handle argument is ignored
// because there is only one adapter object. Always returns
// UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
  adapter.RefCount.fetch_add(1);
  return UR_RESULT_SUCCESS;
}

// Drops one reference from the global adapter while holding its mutex. When
// the count reaches zero, CUDA tracing is switched off. The handle argument
// is ignored. Always returns UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
  std::lock_guard<std::mutex> Guard{adapter.Mutex};
  // fetch_sub returns the value held before the decrement, so a previous
  // value of 1 means the count has just become zero.
  const bool WasLastReference = adapter.RefCount.fetch_sub(1) == 1;
  if (WasLastReference)
    disableCUDATracing();
  return UR_RESULT_SUCCESS;
}

// Exposes the adapter's last recorded error.
//
// ppMessage: receives a pointer to ErrorMessage.
// pError:    receives ErrorMessageCode.
// Both out-parameters are written unconditionally, so callers must pass
// valid pointers. The handle argument is ignored. Always returns
// UR_RESULT_SUCCESS.
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
    ur_adapter_handle_t, const char **ppMessage, int32_t *pError) {
  ppMessage[0] = ErrorMessage;
  pError[0] = ErrorMessageCode;
  return UR_RESULT_SUCCESS;
}

// Answers property queries about the adapter. The handle argument is ignored.
//
// propName: UR_ADAPTER_INFO_BACKEND yields UR_ADAPTER_BACKEND_CUDA;
//           UR_ADAPTER_INFO_REFERENCE_COUNT yields the current value of the
//           global adapter's reference counter. Any other value is rejected
//           with UR_RESULT_ERROR_INVALID_ENUMERATION.
// propSize / pPropValue / pPropSizeRet: forwarded to UrReturnHelper, which
//           produces the result for the supported properties (presumably by
//           writing the value and/or its size through these pointers -
//           confirm against UrReturnHelper).
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
                                                     ur_adapter_info_t propName,
                                                     size_t propSize,
                                                     void *pPropValue,
                                                     size_t *pPropSizeRet) {
  UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

  if (propName == UR_ADAPTER_INFO_BACKEND)
    return ReturnValue(UR_ADAPTER_BACKEND_CUDA);

  if (propName == UR_ADAPTER_INFO_REFERENCE_COUNT)
    return ReturnValue(adapter.RefCount.load());

  return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
11 changes: 11 additions & 0 deletions sycl/plugins/unified_runtime/ur/adapters/cuda/adapter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//===--------- adapter.hpp - CUDA Adapter ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===-----------------------------------------------------------------===//

struct ur_adapter_handle_t_;

extern ur_adapter_handle_t_ adapter;
7 changes: 0 additions & 7 deletions sycl/plugins/unified_runtime/ur/adapters/cuda/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,3 @@ void setPluginSpecificMessage(CUresult cu_res) {
setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC);
free(message);
}

// Returns plugin specific error and warning messages; common implementation
// that can be shared between adapters
ur_result_t urGetLastResult(ur_platform_handle_t, const char **ppMessage) {
*ppMessage = &ErrorMessage[0];
return ErrorMessageCode;
}
7 changes: 5 additions & 2 deletions sycl/plugins/unified_runtime/ur/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cassert>
#include <sstream>

#include "adapter.hpp"
#include "context.hpp"
#include "device.hpp"
#include "platform.hpp"
Expand Down Expand Up @@ -1206,13 +1207,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(

// Get list of platforms
uint32_t NumPlatforms = 0;
ur_result_t Result = urPlatformGet(0, nullptr, &NumPlatforms);
ur_adapter_handle_t AdapterHandle = &adapter;
ur_result_t Result =
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
if (Result != UR_RESULT_SUCCESS)
return Result;

ur_platform_handle_t *Plat = static_cast<ur_platform_handle_t *>(
malloc(NumPlatforms * sizeof(ur_platform_handle_t)));
Result = urPlatformGet(NumPlatforms, Plat, nullptr);
Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, Plat, nullptr);
if (Result != UR_RESULT_SUCCESS)
return Result;

Expand Down
Loading