diff --git a/sycl/source/detail/config.def b/sycl/source/detail/config.def index a685e64076783..74f5cf2a693ab 100644 --- a/sycl/source/detail/config.def +++ b/sycl/source/detail/config.def @@ -31,3 +31,7 @@ CONFIG(SYCL_CACHE_THRESHOLD, 16, __SYCL_CACHE_THRESHOLD) CONFIG(SYCL_CACHE_MIN_DEVICE_IMAGE_SIZE, 16, __SYCL_CACHE_MIN_DEVICE_IMAGE_SIZE) CONFIG(SYCL_CACHE_MAX_DEVICE_IMAGE_SIZE, 16, __SYCL_CACHE_MAX_DEVICE_IMAGE_SIZE) CONFIG(INTEL_ENABLE_OFFLOAD_ANNOTATIONS, 1, __SYCL_INTEL_ENABLE_OFFLOAD_ANNOTATIONS) +CONFIG(SYCL_OVERRIDE_PI_OPENCL, 1024, __SYCL_OVERRIDE_PI_OPENCL) +CONFIG(SYCL_OVERRIDE_PI_LEVEL_ZERO, 1024, __SYCL_OVERRIDE_PI_LEVEL_ZERO) +CONFIG(SYCL_OVERRIDE_PI_CUDA, 1024, __SYCL_OVERRIDE_PI_CUDA) +CONFIG(SYCL_OVERRIDE_PI_ROCM, 1024, __SYCL_OVERRIDE_PI_ROCM) diff --git a/sycl/source/detail/pi.cpp b/sycl/source/detail/pi.cpp index 25109bd3a6669..4a9dc40a4a79b 100644 --- a/sycl/source/detail/pi.cpp +++ b/sycl/source/detail/pi.cpp @@ -225,19 +225,34 @@ std::string memFlagsToString(pi_mem_flags Flags) { std::shared_ptr GlobalPlugin; // Find the plugin at the appropriate location and return the location. -bool findPlugins(std::vector> &PluginNames) { +std::vector> findPlugins() { + std::vector> PluginNames; + // TODO: Based on final design discussions, change the location where the // plugin must be searched; how to identify the plugins etc. Currently the // search is done for libpi_opencl.so/pi_opencl.dll file in LD_LIBRARY_PATH // env only. // + const char *OpenCLPluginName = + SYCLConfig::get() + ? SYCLConfig::get() + : __SYCL_OPENCL_PLUGIN_NAME; + const char *L0PluginName = + SYCLConfig::get() + ? SYCLConfig::get() + : __SYCL_LEVEL_ZERO_PLUGIN_NAME; + const char *CUDAPluginName = SYCLConfig::get() + ? SYCLConfig::get() + : __SYCL_CUDA_PLUGIN_NAME; + const char *ROCMPluginName = SYCLConfig::get() + ? SYCLConfig::get() + : __SYCL_ROCM_PLUGIN_NAME; device_filter_list *FilterList = SYCLConfig::get(); if (!FilterList) { - PluginNames.emplace_back(__SYCL_OPENCL_PLUGIN_NAME, backend::opencl); - PluginNames.emplace_back(__SYCL_LEVEL_ZERO_PLUGIN_NAME, - backend::level_zero); - PluginNames.emplace_back(__SYCL_CUDA_PLUGIN_NAME, backend::cuda); - PluginNames.emplace_back(__SYCL_ROCM_PLUGIN_NAME, backend::rocm); + PluginNames.emplace_back(OpenCLPluginName, backend::opencl); + PluginNames.emplace_back(L0PluginName, backend::level_zero); + PluginNames.emplace_back(CUDAPluginName, backend::cuda); + PluginNames.emplace_back(ROCMPluginName, backend::rocm); } else { std::vector Filters = FilterList->get(); bool OpenCLFound = false; @@ -248,26 +263,25 @@ bool findPlugins(std::vector> &PluginNames) { backend Backend = Filter.Backend; if (!OpenCLFound && (Backend == backend::opencl || Backend == backend::all)) { - PluginNames.emplace_back(__SYCL_OPENCL_PLUGIN_NAME, backend::opencl); + PluginNames.emplace_back(OpenCLPluginName, backend::opencl); OpenCLFound = true; } if (!LevelZeroFound && (Backend == backend::level_zero || Backend == backend::all)) { - PluginNames.emplace_back(__SYCL_LEVEL_ZERO_PLUGIN_NAME, - backend::level_zero); + PluginNames.emplace_back(L0PluginName, backend::level_zero); LevelZeroFound = true; } if (!CudaFound && (Backend == backend::cuda || Backend == backend::all)) { - PluginNames.emplace_back(__SYCL_CUDA_PLUGIN_NAME, backend::cuda); + PluginNames.emplace_back(CUDAPluginName, backend::cuda); CudaFound = true; } if (!RocmFound && (Backend == backend::rocm || Backend == backend::all)) { - PluginNames.emplace_back(__SYCL_ROCM_PLUGIN_NAME, backend::rocm); + PluginNames.emplace_back(ROCMPluginName, backend::rocm); RocmFound = true; } } } - return true; + return PluginNames; } // Load the Plugin by calling the OS dependent library loading call. @@ -321,8 +335,7 @@ const std::vector &initialize() { } static void initializePlugins(std::vector *Plugins) { - std::vector> PluginNames; - findPlugins(PluginNames); + std::vector> PluginNames = findPlugins(); if (PluginNames.empty() && trace(PI_TRACE_ALL)) std::cerr << "SYCL_PI_TRACE[all]: " diff --git a/sycl/test/basic_tests/pluign_overrides_negative.cpp b/sycl/test/basic_tests/pluign_overrides_negative.cpp new file mode 100644 index 0000000000000..cecf969fa07f8 --- /dev/null +++ b/sycl/test/basic_tests/pluign_overrides_negative.cpp @@ -0,0 +1,16 @@ +// RUN: %clangxx -fsycl %s -o %t.out +// RUN: env SYCL_OVERRIDE_PI_OPENCL=opencl_test env SYCL_OVERRIDE_PI_LEVEL_ZERO=l0_test env SYCL_OVERRIDE_PI_CUDA=cuda_test env SYCL_OVERRIDE_PI_ROCM=rocm_test env SYCL_PI_TRACE=-1 %t.out > %t.log 2>&1 +// RUN: FileCheck %s --input-file %t.log + +#include + +int main() { + sycl::queue Q; + + return 0; +} + +// CHECK: SYCL_PI_TRACE[all]: Check if plugin is present. Failed to load plugin: opencl_test +// CHECK: SYCL_PI_TRACE[all]: Check if plugin is present. Failed to load plugin: l0_test +// CHECK: SYCL_PI_TRACE[all]: Check if plugin is present. Failed to load plugin: cuda_test +// CHECK: SYCL_PI_TRACE[all]: Check if plugin is present. Failed to load plugin: rocm_test diff --git a/sycl/test/basic_tests/pluign_overrides_positive.cpp b/sycl/test/basic_tests/pluign_overrides_positive.cpp new file mode 100644 index 0000000000000..63f44ecdf056e --- /dev/null +++ b/sycl/test/basic_tests/pluign_overrides_positive.cpp @@ -0,0 +1,43 @@ +// RUN: %clangxx -fsycl -DFAKE_PLUGIN -shared %s -o %t_fake_plugin.so +// RUN: %clangxx -fsycl %s -o %t.out +// RUN: env SYCL_OVERRIDE_PI_OPENCL=%t_fake_plugin.so env SYCL_OVERRIDE_PI_LEVEL_ZERO=%t_fake_plugin.so env SYCL_OVERRIDE_PI_CUDA=%t_fake_plugin.so env SYCL_OVERRIDE_PI_ROCM=%t_fake_plugin.so env SYCL_PI_TRACE=-1 %t.out > %t.log 2>&1 +// RUN: FileCheck %s --input-file %t.log +// REQUIRES: linux + +#ifdef FAKE_PLUGIN + +#include + +pi_result piPlatformsGet(pi_uint32 NumEntries, pi_platform *Platforms, + pi_uint32 *NumPlatforms) { + return PI_INVALID_OPERATION; +} + +pi_result piTearDown(void *) { return PI_SUCCESS; } + +pi_result piPluginInit(pi_plugin *PluginInit) { + PluginInit->PiFunctionTable.piPlatformsGet = piPlatformsGet; + PluginInit->PiFunctionTable.piTearDown = piTearDown; + return PI_SUCCESS; +} + +#else + +#include + +int main() { + try { + sycl::platform P{sycl::default_selector{}}; + } catch (...) { + // NOP + } + + return 0; +} + +#endif + +// CHECK: SYCL_PI_TRACE[basic]: Plugin found and successfully loaded: {{[0-9a-zA-Z_\/\.-]+}}_fake_plugin.so +// CHECK: SYCL_PI_TRACE[basic]: Plugin found and successfully loaded: {{[0-9a-zA-Z_\/\.-]+}}_fake_plugin.so +// CHECK: SYCL_PI_TRACE[basic]: Plugin found and successfully loaded: {{[0-9a-zA-Z_\/\.-]+}}_fake_plugin.so +// CHECK: SYCL_PI_TRACE[basic]: Plugin found and successfully loaded: {{[0-9a-zA-Z_\/\.-]+}}_fake_plugin.so