Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2062,15 +2062,15 @@ typedef struct ur_device_native_properties_t {
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// + `NULL == hAdapter`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevice`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
UR_APIEXPORT ur_result_t UR_APICALL
urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t *phDevice ///< [out] pointer to the handle of the device object created.
);
Expand Down Expand Up @@ -11972,7 +11972,7 @@ typedef struct ur_device_get_native_handle_params_t {
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_device_create_with_native_handle_params_t {
ur_native_handle_t *phNativeDevice;
ur_platform_handle_t *phPlatform;
ur_adapter_handle_t *phAdapter;
const ur_device_native_properties_t **ppProperties;
ur_device_handle_t **pphDevice;
} ur_device_create_with_native_handle_params_t;
Expand Down
2 changes: 1 addition & 1 deletion include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -2373,7 +2373,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnDeviceGetNativeHandle_t)(
/// @brief Function-pointer for urDeviceCreateWithNativeHandle
typedef ur_result_t(UR_APICALL *ur_pfnDeviceCreateWithNativeHandle_t)(
ur_native_handle_t,
ur_platform_handle_t,
ur_adapter_handle_t,
const ur_device_native_properties_t *,
ur_device_handle_t *);

Expand Down
4 changes: 2 additions & 2 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17357,10 +17357,10 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
*(params->phNativeDevice)));

os << ", ";
os << ".hPlatform = ";
os << ".hAdapter = ";

ur::details::printPtr(os,
*(params->phPlatform));
*(params->phAdapter));

os << ", ";
os << ".pProperties = ";
Expand Down
6 changes: 3 additions & 3 deletions scripts/core/device.yml
Original file line number Diff line number Diff line change
Expand Up @@ -820,9 +820,9 @@ params:
- type: $x_native_handle_t
name: hNativeDevice
desc: "[in][nocheck] the native handle of the device."
- type: $x_platform_handle_t
name: hPlatform
desc: "[in] handle of the platform instance"
- type: $x_adapter_handle_t
name: hAdapter
desc: "[in] handle of the adapter to which `hNativeDevice` belongs"
- type: const $x_device_native_properties_t*
name: pProperties
desc: "[in][optional] pointer to native device properties struct."
Expand Down
17 changes: 3 additions & 14 deletions source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1185,27 +1185,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
/// \return TBD

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
const ur_device_native_properties_t *pProperties,
ur_native_handle_t hNativeDevice,
[[maybe_unused]] ur_adapter_handle_t hAdapter,
[[maybe_unused]] const ur_device_native_properties_t *pProperties,
ur_device_handle_t *phDevice) {
std::ignore = pProperties;

CUdevice CuDevice = static_cast<CUdevice>(hNativeDevice);

auto IsDevice = [=](std::unique_ptr<ur_device_handle_t_> &Dev) {
return Dev->get() == CuDevice;
};

// If a platform is provided just check if the device is in it
if (hPlatform) {
auto SearchRes = std::find_if(begin(hPlatform->Devices),
end(hPlatform->Devices), IsDevice);
if (SearchRes != end(hPlatform->Devices)) {
*phDevice = SearchRes->get();
return UR_RESULT_SUCCESS;
}
}

// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
Expand Down
13 changes: 2 additions & 11 deletions source/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
ur_native_handle_t hNativeDevice,
[[maybe_unused]] ur_adapter_handle_t hAdapter,
[[maybe_unused]] const ur_device_native_properties_t *pProperties,
ur_device_handle_t *phDevice) {
// We can't cast between ur_native_handle_t and hipDevice_t, so memcpy the
Expand All @@ -1000,16 +1001,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
return Dev->get() == HIPDevice;
};

// If a platform is provided just check if the device is in it
if (hPlatform) {
auto SearchRes = std::find_if(begin(hPlatform->Devices),
end(hPlatform->Devices), IsDevice);
if (SearchRes != end(hPlatform->Devices)) {
*phDevice = SearchRes->get();
return UR_RESULT_SUCCESS;
}
}

// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
Expand Down
12 changes: 3 additions & 9 deletions source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,14 +1602,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t NativeDevice, ///< [in] the native handle of the device.
ur_platform_handle_t Platform, ///< [in] handle of the platform instance
const ur_device_native_properties_t
[[maybe_unused]] ur_adapter_handle_t
Adapter, ///< [in] handle of the platform instance
[[maybe_unused]] const ur_device_native_properties_t
*Properties, ///< [in][optional] pointer to native device properties
///< struct.
ur_device_handle_t
*Device ///< [out] pointer to the handle of the device object created.
) {
std::ignore = Properties;
auto ZeDevice = ur_cast<ze_device_handle_t>(NativeDevice);

// The SYCL spec requires that the set of devices must remain fixed for the
Expand All @@ -1622,12 +1622,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
if (const auto *platforms = GlobalAdapter->PlatformCache->get_value()) {
for (const auto &p : *platforms) {
Dev = p->getDeviceFromNativeHandle(ZeDevice);
if (Dev) {
// Check that the input Platform, if was given, matches the found one.
UR_ASSERT(!Platform || Platform == p.get(),
UR_RESULT_ERROR_INVALID_PLATFORM);
break;
}
}
} else {
return GlobalAdapter->PlatformCache->get_error();
Expand Down
5 changes: 3 additions & 2 deletions source/adapters/mock/ur_mockddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -930,7 +931,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

ur_device_create_with_native_handle_params_t params = {
&hNativeDevice, &hPlatform, &pProperties, &phDevice};
&hNativeDevice, &hAdapter, &pProperties, &phDevice};

auto beforeCallback = reinterpret_cast<ur_mock_callback_t>(
mock::getCallbacks().get_before_callback(
Expand Down
4 changes: 2 additions & 2 deletions source/adapters/native_cpu/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
ur_native_handle_t hNativeDevice, ur_adapter_handle_t hAdapter,
const ur_device_native_properties_t *pProperties,
ur_device_handle_t *phDevice) {
std::ignore = hNativeDevice;
std::ignore = hPlatform;
std::ignore = hAdapter;
std::ignore = pProperties;
std::ignore = phDevice;

Expand Down
2 changes: 1 addition & 1 deletion source/adapters/opencl/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_platform_handle_t,
ur_native_handle_t hNativeDevice, ur_adapter_handle_t,
const ur_device_native_properties_t *, ur_device_handle_t *phDevice) {

*phDevice = reinterpret_cast<ur_device_handle_t>(hNativeDevice);
Expand Down
7 changes: 4 additions & 3 deletions source/loader/layers/tracing/ur_trcddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -719,14 +720,14 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
}

ur_device_create_with_native_handle_params_t params = {
&hNativeDevice, &hPlatform, &pProperties, &phDevice};
&hNativeDevice, &hAdapter, &pProperties, &phDevice};
uint64_t instance =
getContext()->notify_begin(UR_FUNCTION_DEVICE_CREATE_WITH_NATIVE_HANDLE,
"urDeviceCreateWithNativeHandle", &params);

getContext()->logger.info("---> urDeviceCreateWithNativeHandle");

ur_result_t result = pfnCreateWithNativeHandle(hNativeDevice, hPlatform,
ur_result_t result = pfnCreateWithNativeHandle(hNativeDevice, hAdapter,
pProperties, phDevice);

getContext()->notify_end(UR_FUNCTION_DEVICE_CREATE_WITH_NATIVE_HANDLE,
Expand Down
12 changes: 9 additions & 3 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -733,7 +734,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
}

if (getContext()->enableParameterValidation) {
if (NULL == hPlatform) {
if (NULL == hAdapter) {
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}

Expand All @@ -742,7 +743,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
}
}

ur_result_t result = pfnCreateWithNativeHandle(hNativeDevice, hPlatform,
if (getContext()->enableLifetimeValidation &&
!getContext()->refCountContext->isReferenceValid(hAdapter)) {
getContext()->refCountContext->logInvalidReference(hAdapter);
}

ur_result_t result = pfnCreateWithNativeHandle(hNativeDevice, hAdapter,
pProperties, phDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
Expand Down
10 changes: 5 additions & 5 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -775,19 +776,18 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
[[maybe_unused]] auto context = getContext();

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_platform_object_t *>(hPlatform)->dditable;
auto dditable = reinterpret_cast<ur_adapter_object_t *>(hAdapter)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Device.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;
hAdapter = reinterpret_cast<ur_adapter_object_t *>(hAdapter)->handle;

// forward to device-platform
result = pfnCreateWithNativeHandle(hNativeDevice, hPlatform, pProperties,
result = pfnCreateWithNativeHandle(hNativeDevice, hAdapter, pProperties,
phDevice);

if (UR_RESULT_SUCCESS != result) {
Expand Down
7 changes: 4 additions & 3 deletions source/loader/ur_libapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1135,15 +1135,16 @@ ur_result_t UR_APICALL urDeviceGetNativeHandle(
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// + `NULL == hAdapter`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevice`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -1155,7 +1156,7 @@ ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
return UR_RESULT_ERROR_UNINITIALIZED;
}

return pfnCreateWithNativeHandle(hNativeDevice, hPlatform, pProperties,
return pfnCreateWithNativeHandle(hNativeDevice, hAdapter, pProperties,
phDevice);
} catch (...) {
return exceptionToResult(std::current_exception());
Expand Down
5 changes: 3 additions & 2 deletions source/ur_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,15 +997,16 @@ ur_result_t UR_APICALL urDeviceGetNativeHandle(
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// + `NULL == hAdapter`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevice`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand Down
2 changes: 1 addition & 1 deletion test/adapters/cuda/urDeviceCreateWithNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ TEST_F(urCudaDeviceCreateWithNativeHandle, Success) {

ur_native_handle_t nativeCuda = static_cast<ur_native_handle_t>(cudaDevice);
ur_device_handle_t urDevice;
ASSERT_SUCCESS(urDeviceCreateWithNativeHandle(nativeCuda, platform, nullptr,
ASSERT_SUCCESS(urDeviceCreateWithNativeHandle(nativeCuda, adapter, nullptr,
&urDevice));
}
8 changes: 4 additions & 4 deletions test/conformance/device/urDeviceCreateWithNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ TEST_F(urDeviceCreateWithNativeHandleTest, Success) {
// and perform some query on it to verify that it works.
ur_device_handle_t dev = nullptr;
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(urDeviceCreateWithNativeHandle(
native_handle, platform, nullptr, &dev));
native_handle, adapter, nullptr, &dev));
ASSERT_NE(dev, nullptr);

uint32_t dev_id = 0;
Expand All @@ -41,7 +41,7 @@ TEST_F(urDeviceCreateWithNativeHandleTest, SuccessWithOwnedNativeHandle) {
ur_device_native_properties_t props{
UR_STRUCTURE_TYPE_DEVICE_NATIVE_PROPERTIES, nullptr, true};
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(urDeviceCreateWithNativeHandle(
native_handle, platform, &props, &dev));
native_handle, adapter, &props, &dev));
ASSERT_NE(dev, nullptr);

uint32_t ref_count = 0;
Expand All @@ -64,7 +64,7 @@ TEST_F(urDeviceCreateWithNativeHandleTest, SuccessWithUnOwnedNativeHandle) {
ur_device_native_properties_t props{
UR_STRUCTURE_TYPE_DEVICE_NATIVE_PROPERTIES, nullptr, false};
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(urDeviceCreateWithNativeHandle(
native_handle, platform, &props, &dev));
native_handle, adapter, &props, &dev));
ASSERT_NE(dev, nullptr);

uint32_t ref_count = 0;
Expand Down Expand Up @@ -93,7 +93,7 @@ TEST_F(urDeviceCreateWithNativeHandleTest, InvalidNullPointerDevice) {
ASSERT_SUCCESS(urDeviceGetNativeHandle(device, &native_handle));

ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER,
urDeviceCreateWithNativeHandle(native_handle, platform,
urDeviceCreateWithNativeHandle(native_handle, adapter,
nullptr, nullptr));
}
}