diff --git a/unified-runtime/scripts/templates/ldrddi.cpp.mako b/unified-runtime/scripts/templates/ldrddi.cpp.mako index c070e281a6b9c..be8ef1a9d1f09 100644 --- a/unified-runtime/scripts/templates/ldrddi.cpp.mako +++ b/unified-runtime/scripts/templates/ldrddi.cpp.mako @@ -56,9 +56,13 @@ namespace ur_loader if (platform.initStatus != ${X}_RESULT_SUCCESS) continue; + auto *${th.make_pfn_name(n, tags, obj)} = platform.dditable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}; + if (*${th.make_pfn_name(n, tags, obj)} == nullptr) + return ${X}_RESULT_ERROR_UNINITIALIZED; + uint32_t adapter; ur_adapter_handle_t *adapterHandle = numAdapters < NumEntries ? &${obj['params'][1]['name']}[numAdapters] : nullptr; - platform.dditable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( 1, adapterHandle, &adapter ); + ${th.make_pfn_name(n, tags, obj)}( 1, adapterHandle, &adapter ); numAdapters += adapter; } @@ -129,6 +133,7 @@ ${tbl['export']['name']}( if(platform.initStatus != ${X}_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast<${tbl['pfn']}>( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "${tbl['export']['name']}")); if(!getTable) diff --git a/unified-runtime/source/common/linux/ur_lib_loader.cpp b/unified-runtime/source/common/linux/ur_lib_loader.cpp index bb49018e77f3b..46e46863876f9 100644 --- a/unified-runtime/source/common/linux/ur_lib_loader.cpp +++ b/unified-runtime/source/common/linux/ur_lib_loader.cpp @@ -78,7 +78,16 @@ LibLoader::loadAdapterLibrary(const char *name) { } void *LibLoader::getFunctionPtr(HMODULE handle, const char *func_name) { - return dlsym(handle, func_name); + // Clear any existing error + dlerror(); + + void *ptr = dlsym(handle, func_name); + const char *err = dlerror(); + if (err) { + UR_LOG(ERR, "dlsym failed to load function '{}': {}", func_name, err); + } + + return ptr; } } // namespace ur_loader diff --git a/unified-runtime/source/loader/ur_ldrddi.cpp b/unified-runtime/source/loader/ur_ldrddi.cpp index 3a366b990055c..354c79bf3267a 100644 --- a/unified-runtime/source/loader/ur_ldrddi.cpp +++ b/unified-runtime/source/loader/ur_ldrddi.cpp @@ -39,10 +39,14 @@ __urdlllocal ur_result_t UR_APICALL urAdapterGet( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto *pfnGet = platform.dditable.Adapter.pfnGet; + if (*pfnGet == nullptr) + return UR_RESULT_ERROR_UNINITIALIZED; + uint32_t adapter; ur_adapter_handle_t *adapterHandle = numAdapters < NumEntries ? &phAdapters[numAdapters] : nullptr; - platform.dditable.Adapter.pfnGet(1, adapterHandle, &adapter); + pfnGet(1, adapterHandle, &adapter); numAdapters += adapter; } @@ -6237,6 +6241,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetAdapterProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetAdapterProcAddrTable")); @@ -6296,6 +6301,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetBindlessImagesExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr( platform.handle.get(), "urGetBindlessImagesExpProcAddrTable")); @@ -6392,6 +6398,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr( platform.handle.get(), "urGetCommandBufferExpProcAddrTable")); @@ -6484,6 +6491,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetContextProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetContextProcAddrTable")); @@ -6543,6 +6551,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetEnqueueProcAddrTable")); @@ -6624,6 +6633,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetEnqueueExpProcAddrTable")); @@ -6687,6 +6697,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEventProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetEventProcAddrTable")); @@ -6747,6 +6758,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetGraphExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetGraphExpProcAddrTable")); @@ -6805,6 +6817,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetIPCExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetIPCExpProcAddrTable")); @@ -6860,6 +6873,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetKernelProcAddrTable")); @@ -6932,6 +6946,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetMemProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetMemProcAddrTable")); @@ -6995,6 +7010,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetMemoryExportExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr( platform.handle.get(), "urGetMemoryExportExpProcAddrTable")); @@ -7053,6 +7069,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetPhysicalMemProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetPhysicalMemProcAddrTable")); @@ -7109,6 +7126,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetPlatformProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetPlatformProcAddrTable")); @@ -7167,6 +7185,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetProgramProcAddrTable")); @@ -7235,6 +7254,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetProgramExpProcAddrTable")); @@ -7291,6 +7311,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetQueueProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetQueueProcAddrTable")); @@ -7351,6 +7372,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetQueueExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetQueueExpProcAddrTable")); @@ -7409,6 +7431,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetSamplerProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetSamplerProcAddrTable")); @@ -7467,6 +7490,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetUSMProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetUSMProcAddrTable")); @@ -7527,6 +7551,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetUSMExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetUSMExpProcAddrTable")); @@ -7591,6 +7616,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetUsmP2PExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetUsmP2PExpProcAddrTable")); @@ -7649,6 +7675,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetVirtualMemProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetVirtualMemProcAddrTable")); @@ -7709,6 +7736,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetDeviceProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetDeviceProcAddrTable")); @@ -7771,6 +7799,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetDeviceExpProcAddrTable( if (platform.initStatus != UR_RESULT_SUCCESS) continue; + auto getTable = reinterpret_cast( ur_loader::LibLoader::getFunctionPtr(platform.handle.get(), "urGetDeviceExpProcAddrTable"));