From f40b269bd949276f53e873fd3177a71405804111 Mon Sep 17 00:00:00 2001 From: "Zhang, Yantao" Date: Tue, 21 Jan 2025 16:01:33 -0800 Subject: [PATCH 1/2] If ONEAPI_DEVICE_SELECTOR requested cpu/fpga devices, do NOT load L0/Cuda/Hip backends --- source/loader/ur_adapter_registry.hpp | 35 +++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/source/loader/ur_adapter_registry.hpp b/source/loader/ur_adapter_registry.hpp index 7df799ab1e..98ed911914 100644 --- a/source/loader/ur_adapter_registry.hpp +++ b/source/loader/ur_adapter_registry.hpp @@ -184,8 +184,13 @@ class AdapterRegistry { using EnvVarMap = std::map>; EnvVarMap mapODS = odsEnvMap.has_value() ? odsEnvMap.value() : EnvVarMap{{"*", {"*"}}}; + // if hasDeivce[1~5] corresponds to cpu, gpu, fpga, *, invalid device + // By default, each entry is false, it's flipped to true if requested by ONEAPI_DEVICE_SELECTOR + bool hasDevice[5]; + enum DeviceType {cpu, gpu, fpga, any, invalid}; for (auto &termPair : mapODS) { std::string backend = termPair.first; + std::vector devices = termPair.second; // TODO: Figure out how to process all ODS errors rather than returning // on the first error. if (backend.empty()) { @@ -223,6 +228,21 @@ class AdapterRegistry { backend); continue; } + + // Verify which devices are requested by ONEAPI_DEVICE_SELECTOR + for(int idev = 0; idev < 5; idev ++) hasDevice[idev] = false; + for(unsigned long int idev = 0; idev < devices.size(); idev++){ + if(strcmp(devices[idev].c_str(), "cpu") == 0) hasDevice[cpu] = true; + else if(strcmp(devices[idev].c_str(), "gpu") == 0) hasDevice[gpu] = true; + else if(strcmp(devices[idev].c_str(), "fpga") == 0) hasDevice[fpga] = true; + else if(strcmp(devices[idev].c_str(), "*") == 0) hasDevice[any] = true; + else { + hasDevice[invalid] = true; + logger::debug("ONEAPI_DEVICE_SELECTOR Pre-Filter with illegal " + "device '{}' ", + devices[idev]); + } + } // case-insensitive comparison by converting both tolower std::transform(platformBackendName.begin(), @@ -234,6 +254,7 @@ class AdapterRegistry { std::size_t nameFound = platformBackendName.find(backend); bool backendFound = nameFound != std::string::npos; + if (termType == AcceptFilter) { if (backend.front() != '*' && !backendFound) { logger::debug( @@ -242,8 +263,17 @@ class AdapterRegistry { backend, platformBackendName); acceptLibrary = false; continue; - } else if (backend.front() == '*' || backendFound) { - return UR_RESULT_SUCCESS; + }else if ( backend.front() == '*' && ( hasDevice[cpu] or hasDevice[fpga] ) && + (platformBackendName.find("level_zero") != std::string::npos || + platformBackendName.find("cuda") != std::string::npos || + platformBackendName.find("hip") != std::string::npos ) ){ + //level_zero, cuda, hip backends only supports gpu devices + //if no gpu devices are requested, reject the platformBackendName + acceptLibrary = false; + continue; + }else if ( backend.front() == '*' || backendFound ) { + acceptLibrary = true; + continue; } } else { if (backendFound || backend.front() == '*') { @@ -256,6 +286,7 @@ class AdapterRegistry { } } } + if (acceptLibrary) { return UR_RESULT_SUCCESS; } From c9c83ebe3de6c14a6df3c02b4e12fc6f32d29af2 Mon Sep 17 00:00:00 2001 From: Yantao Zhang <110424117+ytzhang1@users.noreply.github.com> Date: Wed, 22 Jan 2025 14:06:34 -0800 Subject: [PATCH 2/2] Update ur_adapter_registry.hpp --- source/loader/ur_adapter_registry.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/loader/ur_adapter_registry.hpp b/source/loader/ur_adapter_registry.hpp index 98ed911914..b6a91876e4 100644 --- a/source/loader/ur_adapter_registry.hpp +++ b/source/loader/ur_adapter_registry.hpp @@ -263,7 +263,7 @@ class AdapterRegistry { backend, platformBackendName); acceptLibrary = false; continue; - }else if ( backend.front() == '*' && ( hasDevice[cpu] or hasDevice[fpga] ) && + }else if ( backend.front() == '*' && !( hasDevice[gpu] or hasDevice[any] ) && (platformBackendName.find("level_zero") != std::string::npos || platformBackendName.find("cuda") != std::string::npos || platformBackendName.find("hip") != std::string::npos ) ){