diff --git a/scripts/core/INTRO.rst b/scripts/core/INTRO.rst index 448e3569e2..898d4ce5f3 100644 --- a/scripts/core/INTRO.rst +++ b/scripts/core/INTRO.rst @@ -396,6 +396,14 @@ Specific environment variables can be set to control the behavior of unified run See the Layers_ section for details of the layers currently included in the runtime. +.. envvar:: UR_LOADER_PRELOAD_FILTER + + If set, the loader will read `ONEAPI_DEVICE_SELECTOR` before loading the UR Adapters to determine which backends should be loaded. + + .. note:: + + This environment variable is default enabled on Linux, but default disabled on Windows. + Service identifiers --------------------- diff --git a/source/loader/ur_adapter_registry.hpp b/source/loader/ur_adapter_registry.hpp index 25cd9a9fff..ff07eca100 100644 --- a/source/loader/ur_adapter_registry.hpp +++ b/source/loader/ur_adapter_registry.hpp @@ -152,10 +152,123 @@ class AdapterRegistry { return paths.empty() ? std::nullopt : std::optional(paths); } + ur_result_t readPreFilterODS(std::string platformBackendName) { + // TODO: Refactor this to the common code such that both the prefilter and urDeviceGetSelected use the same functionality. + bool acceptLibrary = true; + std::optional odsEnvMap; + try { + odsEnvMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", false); + + } catch (...) { + // If the selector is malformed, then we ignore selector and return success. + logger::error("ERROR: missing backend, format of filter = " + "'[!]backend:filterStrings'"); + return UR_RESULT_SUCCESS; + } + logger::debug( + "getenv_to_map parsed env var and {} a map", + (odsEnvMap.has_value() ? "produced" : "failed to produce")); + + // if the ODS env var is not set at all, then pretend it was set to the default + using EnvVarMap = std::map>; + EnvVarMap mapODS = + odsEnvMap.has_value() ? odsEnvMap.value() : EnvVarMap{{"*", {"*"}}}; + for (auto &termPair : mapODS) { + std::string backend = termPair.first; + // TODO: Figure out how to process all ODS errors rather than returning + // on the first error. + if (backend.empty()) { + // FIXME: never true because getenv_to_map rejects this case + // malformed term: missing backend -- output ERROR, then continue + logger::error("ERROR: missing backend, format of filter = " + "'[!]backend:filterStrings'"); + continue; + } + logger::debug("ONEAPI_DEVICE_SELECTOR Pre-Filter with backend '{}' " + "and platform library name '{}'", + backend, platformBackendName); + enum FilterType { + AcceptFilter, + DiscardFilter, + } termType = + (backend.front() != '!') ? AcceptFilter : DiscardFilter; + logger::debug( + "termType is {}", + (termType != AcceptFilter ? "DiscardFilter" : "AcceptFilter")); + if (termType != AcceptFilter) { + logger::debug("DEBUG: backend was '{}'", backend); + backend.erase(backend.cbegin()); + logger::debug("DEBUG: backend now '{}'", backend); + } + + // Verify that the backend string is valid, otherwise ignore the backend. + if ((strcmp(backend.c_str(), "*") != 0) && + (strcmp(backend.c_str(), "level_zero") != 0) && + (strcmp(backend.c_str(), "opencl") != 0) && + (strcmp(backend.c_str(), "cuda") != 0) && + (strcmp(backend.c_str(), "hip") != 0)) { + logger::debug("ONEAPI_DEVICE_SELECTOR Pre-Filter with illegal " + "backend '{}' ", + backend); + continue; + } + + // case-insensitive comparison by converting both tolower + std::transform(platformBackendName.begin(), + platformBackendName.end(), + platformBackendName.begin(), + [](unsigned char c) { return std::tolower(c); }); + std::transform(backend.begin(), backend.end(), backend.begin(), + [](unsigned char c) { return std::tolower(c); }); + std::size_t nameFound = platformBackendName.find(backend); + + bool backendFound = nameFound != std::string::npos; + if (termType == AcceptFilter) { + if (backend.front() != '*' && !backendFound) { + logger::debug( + "The ONEAPI_DEVICE_SELECTOR backend name '{}' was not " + "found in the platform library name '{}'", + backend, platformBackendName); + acceptLibrary = false; + continue; + } else if (backend.front() == '*' || backendFound) { + return UR_RESULT_SUCCESS; + } + } else { + if (backendFound || backend.front() == '*') { + acceptLibrary = false; + logger::debug( + "The ONEAPI_DEVICE_SELECTOR backend name for discard " + "'{}' was found in the platform library name '{}'", + backend, platformBackendName); + continue; + } + } + } + if (acceptLibrary) { + return UR_RESULT_SUCCESS; + } + return UR_RESULT_ERROR_INVALID_VALUE; + } + void discoverKnownAdapters() { auto searchPathsEnvOpt = getEnvAdapterSearchPaths(); auto loaderLibPathOpt = getLoaderLibPath(); +#if defined(_WIN32) + bool loaderPreFilter = getenv_tobool("UR_LOADER_PRELOAD_FILTER", false); +#else + bool loaderPreFilter = getenv_tobool("UR_LOADER_PRELOAD_FILTER", true); +#endif for (const auto &adapterName : knownAdapterNames) { + + if (loaderPreFilter) { + if (readPreFilterODS(adapterName) != UR_RESULT_SUCCESS) { + logger::debug("The adapter '{}' was removed based on the " + "pre-filter from ONEAPI_DEVICE_SELECTOR.", + adapterName); + continue; + } + } std::vector loadPaths; // Adapter search order: diff --git a/test/loader/adapter_registry/CMakeLists.txt b/test/loader/adapter_registry/CMakeLists.txt index 2778ad5c40..6d80430e6c 100644 --- a/test/loader/adapter_registry/CMakeLists.txt +++ b/test/loader/adapter_registry/CMakeLists.txt @@ -51,3 +51,7 @@ add_adapter_reg_search_test(search-order SEARCH_PATH ${TEST_SEARCH_PATH} ENVS "TEST_ADAPTER_SEARCH_PATH=\"${TEST_SEARCH_PATH}\"" "TEST_CUR_SEARCH_PATH=\"${TEST_BIN_PATH}\"" SOURCES search_order.cpp) + +add_adapter_reg_search_test(prefilter + SEARCH_PATH "" + SOURCES prefilter.cpp) diff --git a/test/loader/adapter_registry/fixtures.hpp b/test/loader/adapter_registry/fixtures.hpp index 79a831d40f..da5c963e8a 100644 --- a/test/loader/adapter_registry/fixtures.hpp +++ b/test/loader/adapter_registry/fixtures.hpp @@ -74,5 +74,49 @@ struct adapterRegSearchTest : ::testing::Test { } } }; +#ifndef _WIN32 +struct adapterPreFilterTest : ::testing::Test { + ur_loader::AdapterRegistry *registry; + const fs::path levelzeroLibName = + MAKE_LIBRARY_NAME("ur_adapter_level_zero", "0"); + std::function islevelzeroLibName = + [this](const fs::path &path) { return path == levelzeroLibName; }; + + std::function &)> haslevelzeroLibName = + [this](const std::vector &paths) { + return std::any_of(paths.cbegin(), paths.cend(), + islevelzeroLibName); + }; + + const fs::path openclLibName = MAKE_LIBRARY_NAME("ur_adapter_opencl", "0"); + std::function isOpenclLibName = + [this](const fs::path &path) { return path == openclLibName; }; + + std::function &)> hasOpenclLibName = + [this](const std::vector &paths) { + return std::any_of(paths.cbegin(), paths.cend(), isOpenclLibName); + }; + + const fs::path cudaLibName = MAKE_LIBRARY_NAME("ur_adapter_cuda", "0"); + std::function isCudaLibName = + [this](const fs::path &path) { return path == cudaLibName; }; + + std::function &)> hasCudaLibName = + [this](const std::vector &paths) { + return std::any_of(paths.cbegin(), paths.cend(), isCudaLibName); + }; + + void SetUp(std::string filter) { + try { + setenv("ONEAPI_DEVICE_SELECTOR", filter.c_str(), 1); + registry = new ur_loader::AdapterRegistry; + } catch (const std::invalid_argument &e) { + FAIL() << e.what(); + } + } + void SetUp() override {} + void TearDown() override { delete registry; } +}; +#endif #endif // UR_ADAPTER_REG_TEST_HELPERS_H diff --git a/test/loader/adapter_registry/prefilter.cpp b/test/loader/adapter_registry/prefilter.cpp new file mode 100644 index 0000000000..1d2b095da3 --- /dev/null +++ b/test/loader/adapter_registry/prefilter.cpp @@ -0,0 +1,140 @@ +// Copyright (C) 2024 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "fixtures.hpp" + +#ifndef _WIN32 + +TEST_F(adapterPreFilterTest, testPrefilterAcceptFilterSingleBackend) { + SetUp("level_zero:*"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_TRUE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_FALSE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_FALSE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterAcceptFilterMultipleBackends) { + SetUp("level_zero:*;opencl:*"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_TRUE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_TRUE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_FALSE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterDiscardFilterSingleBackend) { + SetUp("!level_zero:*"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_FALSE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_TRUE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_TRUE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterDiscardFilterMultipleBackends) { + SetUp("!level_zero:*;!cuda:*"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_FALSE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_TRUE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_FALSE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterAcceptAndDiscardFilter) { + SetUp("!cuda:*;level_zero:*"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_TRUE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_FALSE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_FALSE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterDiscardFilterAll) { + SetUp("*"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_TRUE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_TRUE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_TRUE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterWithInvalidMissingBackend) { + SetUp(":garbage"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_TRUE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_TRUE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_TRUE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterWithInvalidBackend) { + SetUp("garbage:0"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_TRUE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_TRUE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_TRUE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterWithNotAllAndAcceptFilter) { + SetUp("!*;level_zero"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_TRUE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_FALSE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_FALSE(cudaExists); +} + +TEST_F(adapterPreFilterTest, testPrefilterWithNotAllFilter) { + SetUp("!*"); + auto levelZeroExists = + std::any_of(registry->cbegin(), registry->cend(), haslevelzeroLibName); + EXPECT_FALSE(levelZeroExists); + auto openclExists = + std::any_of(registry->cbegin(), registry->cend(), hasOpenclLibName); + EXPECT_FALSE(openclExists); + auto cudaExists = + std::any_of(registry->cbegin(), registry->cend(), hasCudaLibName); + EXPECT_FALSE(cudaExists); +} + +#endif