diff --git a/source/adapters/level_zero/adapter.cpp b/source/adapters/level_zero/adapter.cpp index 0605b9a40c..5ae1d52e7b 100644 --- a/source/adapters/level_zero/adapter.cpp +++ b/source/adapters/level_zero/adapter.cpp @@ -43,15 +43,31 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try { } std::vector ZeDrivers; + std::vector ZeDevices; ZeDrivers.resize(ZeDriverCount); ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, ZeDrivers.data())); for (uint32_t I = 0; I < ZeDriverCount; ++I) { - auto platform = std::make_unique(ZeDrivers[I]); - UR_CALL(platform->initialize()); - - // Save a copy in the cache for future uses. - platforms.push_back(std::move(platform)); + ze_device_properties_t device_properties{}; + device_properties.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + uint32_t ZeDeviceCount = 0; + ZE2UR_CALL(zeDeviceGet, (ZeDrivers[I], &ZeDeviceCount, nullptr)); + ZeDevices.resize(ZeDeviceCount); + ZE2UR_CALL(zeDeviceGet, (ZeDrivers[I], &ZeDeviceCount, ZeDevices.data())); + // Check if this driver has GPU Devices + for (uint32_t D = 0; D < ZeDeviceCount; ++D) { + ZE2UR_CALL(zeDeviceGetProperties, (ZeDevices[D], &device_properties)); + + if (ZE_DEVICE_TYPE_GPU == device_properties.type) { + // If this Driver is a GPU, save it as a usable platform. + auto platform = std::make_unique(ZeDrivers[I]); + UR_CALL(platform->initialize()); + + // Save a copy in the cache for future uses. + platforms.push_back(std::move(platform)); + break; + } + } } return UR_RESULT_SUCCESS; } catch (...) { @@ -105,8 +121,16 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() // We must only initialize the driver once, even if urPlatformGet() is // called multiple times. Declaring the return value as "static" ensures // it's only called once. - GlobalAdapter->ZeResult = - ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY)); + + // Init with all flags set to enable for all driver types to be init in + // the application. + ze_init_flags_t L0InitFlags = ZE_INIT_FLAG_GPU_ONLY; + if (UrL0InitAllDrivers) { + L0InitFlags |= ZE_INIT_FLAG_VPU_ONLY; + } + logger::debug("\nzeInit with flags value of {}\n", + static_cast(L0InitFlags)); + GlobalAdapter->ZeResult = ZE_CALL_NOCHECK(zeInit, (L0InitFlags)); } assert(GlobalAdapter->ZeResult != std::nullopt); // verify that level-zero is initialized diff --git a/source/adapters/level_zero/common.hpp b/source/adapters/level_zero/common.hpp index a0f94a750e..5784d5bf78 100644 --- a/source/adapters/level_zero/common.hpp +++ b/source/adapters/level_zero/common.hpp @@ -207,6 +207,15 @@ const int UrL0LeaksDebug = [] { return std::atoi(UrRet); }(); +// Enable for UR L0 Adapter to Init all L0 Drivers on the system with filtering +// in place for only currently used Drivers. +const int UrL0InitAllDrivers = [] { + const char *UrRet = std::getenv("UR_L0_INIT_ALL_DRIVERS"); + if (!UrRet) + return 0; + return std::atoi(UrRet); +}(); + // Controls Level Zero calls serialization to w/a Level Zero driver being not MT // ready. Recognized values (can be used as a bit mask): enum {