diff --git a/sycl/include/sycl/detail/ur.hpp b/sycl/include/sycl/detail/ur.hpp index 48e876e3adb6f..c845fb1440aad 100644 --- a/sycl/include/sycl/detail/ur.hpp +++ b/sycl/include/sycl/detail/ur.hpp @@ -64,18 +64,6 @@ template __SYCL_EXPORT void *getPluginOpaqueData(void *opaquedata_arg); namespace ur { -// Function to load a shared library -// Implementation is OS dependent -void *loadOsLibrary(const std::string &Library); - -// Function to unload a shared library -// Implementation is OS dependent (see posix-ur.cpp and windows-ur.cpp) -int unloadOsLibrary(void *Library); - -// Function to get Address of a symbol defined in the shared -// library, implementation is OS dependent. -void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName); - // Performs UR one-time initialization. std::vector & initializeUr(ur_loader_config_handle_t LoaderConfig = nullptr); diff --git a/sycl/source/CMakeLists.txt b/sycl/source/CMakeLists.txt index d02dbb725637a..7c510d4ebbd34 100644 --- a/sycl/source/CMakeLists.txt +++ b/sycl/source/CMakeLists.txt @@ -296,7 +296,8 @@ set(SYCL_COMMON_SOURCES "spirv_ops.cpp" "virtual_mem.cpp" "$<$:detail/windows_ur.cpp>" - "$<$,$>:detail/posix_ur.cpp>" + "$<$:detail/windows_dlopen.cpp>" + "$<$,$>:detail/posix_dlopen.cpp>" ) set(SYCL_NON_PREVIEW_SOURCES "${SYCL_COMMON_SOURCES}" diff --git a/sycl/source/detail/dlopen_utils.hpp b/sycl/source/detail/dlopen_utils.hpp new file mode 100644 index 0000000000000..2190bc77570bf --- /dev/null +++ b/sycl/source/detail/dlopen_utils.hpp @@ -0,0 +1,29 @@ +//===------ dlopen_utils - Helpers for libraries loading -------*- C++ -*--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace sycl { +inline namespace _V1 { +namespace detail { + +// Function to load a shared library +// Implementation is OS dependent +void *loadOsLibrary(const std::string &Library); + +// Function to unload a shared library +// Implementation is OS dependent (see posix-pi.cpp and windows-pi.cpp) +int unloadOsLibrary(void *Library); + +// Function to get Address of a symbol defined in the shared +// library, implementation is OS dependent. +void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName); + +}}} diff --git a/sycl/source/detail/jit_compiler.cpp b/sycl/source/detail/jit_compiler.cpp index 909fc751772dc..d1ae2950072a2 100644 --- a/sycl/source/detail/jit_compiler.cpp +++ b/sycl/source/detail/jit_compiler.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -32,14 +33,14 @@ jit_compiler::jit_compiler() { auto checkJITLibrary = [this]() -> bool { static const std::string JITLibraryName = "libsycl-fusion.so"; - void *LibraryPtr = sycl::detail::ur::loadOsLibrary(JITLibraryName); + void *LibraryPtr = sycl::detail::loadOsLibrary(JITLibraryName); if (LibraryPtr == nullptr) { printPerformanceWarning("Could not find JIT library " + JITLibraryName); return false; } this->AddToConfigHandle = reinterpret_cast( - sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr, + sycl::detail::getOsLibraryFuncAddress(LibraryPtr, "addToJITConfiguration")); if (!this->AddToConfigHandle) { printPerformanceWarning( @@ -48,7 +49,7 @@ jit_compiler::jit_compiler() { } this->ResetConfigHandle = reinterpret_cast( - sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr, + sycl::detail::getOsLibraryFuncAddress(LibraryPtr, "resetJITConfiguration")); if (!this->ResetConfigHandle) { printPerformanceWarning( @@ -57,7 +58,7 @@ jit_compiler::jit_compiler() { } this->FuseKernelsHandle = reinterpret_cast( - sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr, "fuseKernels")); + sycl::detail::getOsLibraryFuncAddress(LibraryPtr, "fuseKernels")); if (!this->FuseKernelsHandle) { printPerformanceWarning( "Cannot resolve JIT library function entry point"); @@ -66,7 +67,7 @@ jit_compiler::jit_compiler() { this->MaterializeSpecConstHandle = reinterpret_cast( - sycl::detail::ur::getOsLibraryFuncAddress( + sycl::detail::getOsLibraryFuncAddress( LibraryPtr, "materializeSpecConstants")); if (!this->MaterializeSpecConstHandle) { printPerformanceWarning( diff --git a/sycl/source/detail/kernel_compiler/kernel_compiler_opencl.cpp b/sycl/source/detail/kernel_compiler/kernel_compiler_opencl.cpp index 3f796f5c647ab..c3f0693296c05 100644 --- a/sycl/source/detail/kernel_compiler/kernel_compiler_opencl.cpp +++ b/sycl/source/detail/kernel_compiler/kernel_compiler_opencl.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include // getOsLibraryFuncAddress +#include #include // make_error_code #include "kernel_compiler_opencl.hpp" @@ -26,7 +26,7 @@ namespace detail { // ensures the OclocLibrary has the right version, etc. void checkOclocLibrary(void *OclocLibrary) { void *OclocVersionHandle = - sycl::detail::ur::getOsLibraryFuncAddress(OclocLibrary, "oclocVersion"); + sycl::detail::getOsLibraryFuncAddress(OclocLibrary, "oclocVersion"); // The initial versions of ocloc library did not have the oclocVersion() // function. Those versions had the same API as the first version of ocloc // library having that oclocVersion() function. @@ -66,7 +66,7 @@ void *loadOclocLibrary() { #endif void *tempPtr = OclocLibrary; if (tempPtr == nullptr) { - tempPtr = sycl::detail::ur::loadOsLibrary(OclocLibraryName); + tempPtr = sycl::detail::loadOsLibrary(OclocLibraryName); if (tempPtr == nullptr) throw sycl::exception(make_error_code(errc::build), @@ -103,11 +103,11 @@ void SetupLibrary(voidPtr &oclocInvokeHandle, voidPtr &oclocFreeOutputHandle, loadOclocLibrary(); oclocInvokeHandle = - sycl::detail::ur::getOsLibraryFuncAddress(OclocLibrary, "oclocInvoke"); + sycl::detail::getOsLibraryFuncAddress(OclocLibrary, "oclocInvoke"); if (!oclocInvokeHandle) throw sycl::exception(the_errc, "Cannot load oclocInvoke() function"); - oclocFreeOutputHandle = sycl::detail::ur::getOsLibraryFuncAddress( + oclocFreeOutputHandle = sycl::detail::getOsLibraryFuncAddress( OclocLibrary, "oclocFreeOutput"); if (!oclocFreeOutputHandle) throw sycl::exception(the_errc, "Cannot load oclocFreeOutput() function"); diff --git a/sycl/source/detail/online_compiler/online_compiler.cpp b/sycl/source/detail/online_compiler/online_compiler.cpp index 5d3c3a381607b..5576728a390fd 100644 --- a/sycl/source/detail/online_compiler/online_compiler.cpp +++ b/sycl/source/detail/online_compiler/online_compiler.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -94,12 +95,12 @@ compileToSPIRV(const std::string &Source, sycl::info::device_type DeviceType, #else static const std::string OclocLibraryName = "libocloc.so"; #endif - void *OclocLibrary = sycl::detail::ur::loadOsLibrary(OclocLibraryName); + void *OclocLibrary = sycl::detail::loadOsLibrary(OclocLibraryName); if (!OclocLibrary) throw online_compile_error("Cannot load ocloc library: " + OclocLibraryName); void *OclocVersionHandle = - sycl::detail::ur::getOsLibraryFuncAddress(OclocLibrary, "oclocVersion"); + sycl::detail::getOsLibraryFuncAddress(OclocLibrary, "oclocVersion"); // The initial versions of ocloc library did not have the oclocVersion() // function. Those versions had the same API as the first version of ocloc // library having that oclocVersion() function. @@ -126,10 +127,10 @@ compileToSPIRV(const std::string &Source, sycl::info::device_type DeviceType, ".N), where (N >= " + std::to_string(CurrentVersionMinor) + ")."); CompileToSPIRVHandle = - sycl::detail::ur::getOsLibraryFuncAddress(OclocLibrary, "oclocInvoke"); + sycl::detail::getOsLibraryFuncAddress(OclocLibrary, "oclocInvoke"); if (!CompileToSPIRVHandle) throw online_compile_error("Cannot load oclocInvoke() function"); - FreeSPIRVOutputsHandle = sycl::detail::ur::getOsLibraryFuncAddress( + FreeSPIRVOutputsHandle = sycl::detail::getOsLibraryFuncAddress( OclocLibrary, "oclocFreeOutput"); if (!FreeSPIRVOutputsHandle) throw online_compile_error("Cannot load oclocFreeOutput() function"); diff --git a/sycl/source/detail/posix_ur.cpp b/sycl/source/detail/posix_dlopen.cpp similarity index 94% rename from sycl/source/detail/posix_ur.cpp rename to sycl/source/detail/posix_dlopen.cpp index 8ca9991a03363..260fd2c11465e 100644 --- a/sycl/source/detail/posix_ur.cpp +++ b/sycl/source/detail/posix_dlopen.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include #include @@ -15,7 +16,7 @@ namespace sycl { inline namespace _V1 { -namespace detail::ur { +namespace detail { void *loadOsLibrary(const std::string &LibraryPath) { // TODO: Check if the option RTLD_NOW is correct. Explore using @@ -35,6 +36,6 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) { return dlsym(Library, FunctionName.c_str()); } -} // namespace detail::ur +} // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/source/detail/windows_dlopen.cpp b/sycl/source/detail/windows_dlopen.cpp new file mode 100644 index 0000000000000..a001ac2dda653 --- /dev/null +++ b/sycl/source/detail/windows_dlopen.cpp @@ -0,0 +1,51 @@ +//==---------------- windows_dlopen.cpp ------------------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "dlopen_utils.hpp" + +#include +#include +#include + +namespace sycl { +inline namespace _V1 { +namespace detail { + +void *loadOsLibrary(const std::string &LibraryPath) { + // Tells the system to not display the critical-error-handler message box. + // Instead, the system sends the error to the calling process. + // This is crucial for graceful handling of shared libs that can't be + // loaded, e.g. due to missing native run-times. + + UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS); + // Exclude current directory from DLL search path + if (!SetDllDirectoryA("")) { + assert(false && "Failed to update DLL search path"); + } + + auto Result = (void *)LoadLibraryExA(LibraryPath.c_str(), NULL, NULL); + (void)SetErrorMode(SavedMode); + if (!SetDllDirectoryA(nullptr)) { + assert(false && "Failed to restore DLL search path"); + } + + return Result; +} + +int unloadOsLibrary(void *Library) { + return (int)FreeLibrary((HMODULE)Library); +} + +void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) { + return reinterpret_cast( + GetProcAddress((HMODULE)Library, FunctionName.c_str())); +} + +} // namespace detail +} // namespace _V1 +} // namespace sycl diff --git a/sycl/source/detail/windows_ur.cpp b/sycl/source/detail/windows_ur.cpp index f730b087a67af..4071cc1717457 100644 --- a/sycl/source/detail/windows_ur.cpp +++ b/sycl/source/detail/windows_ur.cpp @@ -6,52 +6,38 @@ // //===----------------------------------------------------------------------===// +#include + #include #include #include #include #include -#include #include "detail/windows_os_utils.hpp" #include "ur_win_proxy_loader.hpp" namespace sycl { inline namespace _V1 { -namespace detail { -namespace ur { - -void *loadOsLibrary(const std::string &LibraryPath) { - // Tells the system to not display the critical-error-handler message box. - // Instead, the system sends the error to the calling process. - // This is crucial for graceful handling of shared libs that can't be - // loaded, e.g. due to missing native run-times. - - UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS); - // Exclude current directory from DLL search path - if (!SetDllDirectoryA("")) { - assert(false && "Failed to update DLL search path"); - } +namespace detail::ur { - auto Result = (void *)LoadLibraryExA(LibraryPath.c_str(), NULL, NULL); - (void)SetErrorMode(SavedMode); - if (!SetDllDirectoryA(nullptr)) { - assert(false && "Failed to restore DLL search path"); - } +void *loadOsPluginLibrary(const std::string &PluginPath) { + // We fetch the preloaded plugin from the pi_win_proxy_loader. + // The proxy_loader handles any required error suppression. + auto Result = getPreloadedPlugin(PluginPath); return Result; } -int unloadOsLibrary(void *Library) { +int unloadOsPluginLibrary(void *Library) { + // The mock plugin does not have an associated library, so we allow nullptr + // here to avoid it trying to free a non-existent library. + if (!Library) + return 1; return (int)FreeLibrary((HMODULE)Library); } -void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) { - return reinterpret_cast( - GetProcAddress((HMODULE)Library, FunctionName.c_str())); -} - static std::filesystem::path getCurrentDSODirPath() { wchar_t Path[MAX_PATH]; auto Handle = @@ -70,7 +56,6 @@ static std::filesystem::path getCurrentDSODirPath() { return std::filesystem::path(Path); } -} // namespace ur -} // namespace detail +} // namespace detail::ur } // namespace _V1 } // namespace sycl