Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions examples/hello_world/hello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,36 @@ int main(int argc, char *argv[]) {
}
std::cout << "Platform initialized.\n";

uint32_t adapterCount = 0;
std::vector<ur_adapter_handle_t> adapters;
uint32_t platformCount = 0;
std::vector<ur_platform_handle_t> platforms;

status = urPlatformGet(1, nullptr, &platformCount);
status = urAdapterGet(0, nullptr, &adapterCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urAdapterGet failed with return code: " << status
<< std::endl;
return 1;
}
adapters.resize(adapterCount);
status = urAdapterGet(adapterCount, adapters.data(), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urAdapterGet failed with return code: " << status
<< std::endl;
return 1;
}

status = urPlatformGet(adapters.data(), adapterCount, 1, nullptr,
&platformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}

platforms.resize(platformCount);
status = urPlatformGet(platformCount, platforms.data(), nullptr);
status = urPlatformGet(adapters.data(), adapterCount, platformCount,
platforms.data(), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
Expand Down Expand Up @@ -98,6 +116,9 @@ int main(int argc, char *argv[]) {
}

out:
for (auto adapter : adapters) {
urAdapterRelease(adapter);
}
urTearDown(nullptr);
return status == UR_RESULT_SUCCESS ? 0 : 1;
}
101 changes: 88 additions & 13 deletions include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ class ur_function_v(IntEnum):
BINDLESS_IMAGES_DESTROY_EXTERNAL_SEMAPHORE_EXP = 147## Enumerator for ::urBindlessImagesDestroyExternalSemaphoreExp
BINDLESS_IMAGES_WAIT_EXTERNAL_SEMAPHORE_EXP = 148 ## Enumerator for ::urBindlessImagesWaitExternalSemaphoreExp
BINDLESS_IMAGES_SIGNAL_EXTERNAL_SEMAPHORE_EXP = 149 ## Enumerator for ::urBindlessImagesSignalExternalSemaphoreExp
PLATFORM_GET_LAST_ERROR = 150 ## Enumerator for ::urPlatformGetLastError
ENQUEUE_USM_FILL_2D = 151 ## Enumerator for ::urEnqueueUSMFill2D
ENQUEUE_USM_MEMCPY_2D = 152 ## Enumerator for ::urEnqueueUSMMemcpy2D
VIRTUAL_MEM_GRANULARITY_GET_INFO = 153 ## Enumerator for ::urVirtualMemGranularityGetInfo
Expand All @@ -192,6 +191,11 @@ class ur_function_v(IntEnum):
LOADER_CONFIG_RETAIN = 174 ## Enumerator for ::urLoaderConfigRetain
LOADER_CONFIG_GET_INFO = 175 ## Enumerator for ::urLoaderConfigGetInfo
LOADER_CONFIG_ENABLE_LAYER = 176 ## Enumerator for ::urLoaderConfigEnableLayer
ADAPTER_RELEASE = 177 ## Enumerator for ::urAdapterRelease
ADAPTER_GET = 178 ## Enumerator for ::urAdapterGet
ADAPTER_RETAIN = 179 ## Enumerator for ::urAdapterRetain
ADAPTER_GET_LAST_ERROR = 180 ## Enumerator for ::urAdapterGetLastError
ADAPTER_GET_INFO = 181 ## Enumerator for ::urAdapterGetInfo

class ur_function_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -288,6 +292,11 @@ class ur_bool_t(c_ubyte):
class ur_loader_config_handle_t(c_void_p):
pass

###############################################################################
## @brief Handle of an adapter instance
class ur_adapter_handle_t(c_void_p):
pass

###############################################################################
## @brief Handle of a platform instance
class ur_platform_handle_t(c_void_p):
Expand Down Expand Up @@ -501,6 +510,36 @@ def __str__(self):
return str(ur_loader_config_info_v(self.value))


###############################################################################
## @brief Supported adapter info
class ur_adapter_info_v(IntEnum):
BACKEND = 0 ## [::ur_adapter_backend_t] Identifies the native backend supported by
## the adapter.
REFERENCE_COUNT = 1 ## [uint32_t] Reference count of the adapter.
## The reference count returned should be considered immediately stale.
## It is unsuitable for general use in applications. This feature is
## provided for identifying memory leaks.

class ur_adapter_info_t(c_int):
def __str__(self):
return str(ur_adapter_info_v(self.value))


###############################################################################
## @brief Identifies backend of the adapter
class ur_adapter_backend_v(IntEnum):
UNKNOWN = 0 ## The backend is not a recognized one
LEVEL_ZERO = 1 ## The backend is Level Zero
OPENCL = 2 ## The backend is OpenCL
CUDA = 3 ## The backend is CUDA
HIP = 4 ## The backend is HIP
NATIVE_CPU = 5 ## The backend is Native CPU

class ur_adapter_backend_t(c_int):
def __str__(self):
return str(ur_adapter_backend_v(self.value))


###############################################################################
## @brief Supported platform info
class ur_platform_info_v(IntEnum):
Expand Down Expand Up @@ -2273,9 +2312,9 @@ class ur_loader_config_dditable_t(Structure):
###############################################################################
## @brief Function-pointer for urPlatformGet
if __use_win_types:
_urPlatformGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
_urPlatformGet_t = WINFUNCTYPE( ur_result_t, POINTER(ur_adapter_handle_t), c_ulong, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
else:
_urPlatformGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
_urPlatformGet_t = CFUNCTYPE( ur_result_t, POINTER(ur_adapter_handle_t), c_ulong, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )

###############################################################################
## @brief Function-pointer for urPlatformGetInfo
Expand All @@ -2298,13 +2337,6 @@ class ur_loader_config_dditable_t(Structure):
else:
_urPlatformCreateWithNativeHandle_t = CFUNCTYPE( ur_result_t, ur_native_handle_t, POINTER(ur_platform_native_properties_t), POINTER(ur_platform_handle_t) )

###############################################################################
## @brief Function-pointer for urPlatformGetLastError
if __use_win_types:
_urPlatformGetLastError_t = WINFUNCTYPE( ur_result_t, ur_platform_handle_t, POINTER(c_char_p), POINTER(c_long) )
else:
_urPlatformGetLastError_t = CFUNCTYPE( ur_result_t, ur_platform_handle_t, POINTER(c_char_p), POINTER(c_long) )

###############################################################################
## @brief Function-pointer for urPlatformGetApiVersion
if __use_win_types:
Expand All @@ -2328,7 +2360,6 @@ class ur_platform_dditable_t(Structure):
("pfnGetInfo", c_void_p), ## _urPlatformGetInfo_t
("pfnGetNativeHandle", c_void_p), ## _urPlatformGetNativeHandle_t
("pfnCreateWithNativeHandle", c_void_p), ## _urPlatformCreateWithNativeHandle_t
("pfnGetLastError", c_void_p), ## _urPlatformGetLastError_t
("pfnGetApiVersion", c_void_p), ## _urPlatformGetApiVersion_t
("pfnGetBackendOption", c_void_p) ## _urPlatformGetBackendOption_t
]
Expand Down Expand Up @@ -3565,13 +3596,53 @@ class ur_usm_p2p_exp_dditable_t(Structure):
else:
_urTearDown_t = CFUNCTYPE( ur_result_t, c_void_p )

###############################################################################
## @brief Function-pointer for urAdapterGet
if __use_win_types:
_urAdapterGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )
else:
_urAdapterGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )

###############################################################################
## @brief Function-pointer for urAdapterRelease
if __use_win_types:
_urAdapterRelease_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRelease_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterRetain
if __use_win_types:
_urAdapterRetain_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRetain_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterGetLastError
if __use_win_types:
_urAdapterGetLastError_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )
else:
_urAdapterGetLastError_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )

###############################################################################
## @brief Function-pointer for urAdapterGetInfo
if __use_win_types:
_urAdapterGetInfo_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
else:
_urAdapterGetInfo_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )


###############################################################################
## @brief Table of Global functions pointers
class ur_global_dditable_t(Structure):
_fields_ = [
("pfnInit", c_void_p), ## _urInit_t
("pfnTearDown", c_void_p) ## _urTearDown_t
("pfnTearDown", c_void_p), ## _urTearDown_t
("pfnAdapterGet", c_void_p), ## _urAdapterGet_t
("pfnAdapterRelease", c_void_p), ## _urAdapterRelease_t
("pfnAdapterRetain", c_void_p), ## _urAdapterRetain_t
("pfnAdapterGetLastError", c_void_p), ## _urAdapterGetLastError_t
("pfnAdapterGetInfo", c_void_p) ## _urAdapterGetInfo_t
]

###############################################################################
Expand Down Expand Up @@ -3768,7 +3839,6 @@ def __init__(self, version : ur_api_version_t):
self.urPlatformGetInfo = _urPlatformGetInfo_t(self.__dditable.Platform.pfnGetInfo)
self.urPlatformGetNativeHandle = _urPlatformGetNativeHandle_t(self.__dditable.Platform.pfnGetNativeHandle)
self.urPlatformCreateWithNativeHandle = _urPlatformCreateWithNativeHandle_t(self.__dditable.Platform.pfnCreateWithNativeHandle)
self.urPlatformGetLastError = _urPlatformGetLastError_t(self.__dditable.Platform.pfnGetLastError)
self.urPlatformGetApiVersion = _urPlatformGetApiVersion_t(self.__dditable.Platform.pfnGetApiVersion)
self.urPlatformGetBackendOption = _urPlatformGetBackendOption_t(self.__dditable.Platform.pfnGetBackendOption)

Expand Down Expand Up @@ -4048,6 +4118,11 @@ def __init__(self, version : ur_api_version_t):
# attach function interface to function address
self.urInit = _urInit_t(self.__dditable.Global.pfnInit)
self.urTearDown = _urTearDown_t(self.__dditable.Global.pfnTearDown)
self.urAdapterGet = _urAdapterGet_t(self.__dditable.Global.pfnAdapterGet)
self.urAdapterRelease = _urAdapterRelease_t(self.__dditable.Global.pfnAdapterRelease)
self.urAdapterRetain = _urAdapterRetain_t(self.__dditable.Global.pfnAdapterRetain)
self.urAdapterGetLastError = _urAdapterGetLastError_t(self.__dditable.Global.pfnAdapterGetLastError)
self.urAdapterGetInfo = _urAdapterGetInfo_t(self.__dditable.Global.pfnAdapterGetInfo)

# call driver to get function pointers
VirtualMem = ur_virtual_mem_dditable_t()
Expand Down
Loading