Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,17 @@ namespace ur_loader
%else:
<%param_replacements={}%>
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
%if 0 == i:
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
// extract platform's function pointer table
auto dditable = reinterpret_cast<${item['obj']}*>( ${item['pointer']}${item['name']} )->dditable;
auto ${th.make_pfn_name(n, tags, obj)} = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNINITIALIZED;

<%break%>
%endif
%endfor
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
%if 'range' in item:
<%
add_local = True
Expand All @@ -143,13 +146,15 @@ namespace ur_loader
for( size_t i = ${item['range'][0]}; i < ${item['range'][1]}; ++i )
${item['name']}Local[ i ] = reinterpret_cast<${item['obj']}*>( ${item['name']}[ i ] )->handle;
%else:
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
// convert loader handle to platform handle
%if item['optional']:
${item['name']} = ( ${item['name']} ) ? reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle : nullptr;
%else:
${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle;
%endif
%endif
%endif

%endfor
// forward to device-platform
Expand All @@ -170,7 +175,7 @@ namespace ur_loader
%if item['release']:
// release loader handle
${item['factory']}.release( ${item['name']} );
%else:
%elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
try
{
%if 'range' in item:
Expand Down
129 changes: 9 additions & 120 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,6 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativePlatform = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativePlatform, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand Down Expand Up @@ -670,14 +662,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeDevice = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeDevice, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -696,17 +680,13 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->dditable;
reinterpret_cast<ur_platform_object_t *>(hPlatform)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Device.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeDevice =
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->handle;

// convert loader handle to platform handle
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;

Expand Down Expand Up @@ -913,14 +893,6 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeContext, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -941,17 +913,13 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeContext)->dditable;
reinterpret_cast<ur_device_object_t *>(*phDevices)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Context.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeContext =
reinterpret_cast<ur_native_object_t *>(hNativeContext)->handle;

// convert loader handles to platform handles
auto phDevicesLocal = std::vector<ur_device_handle_t>(numDevices);
for (size_t i = 0; i < numDevices; ++i) {
Expand Down Expand Up @@ -1204,14 +1172,6 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeMem, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -1229,17 +1189,13 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnBufferCreateWithNativeHandle =
dditable->ur.Mem.pfnBufferCreateWithNativeHandle;
if (nullptr == pfnBufferCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -1279,17 +1235,13 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnImageCreateWithNativeHandle =
dditable->ur.Mem.pfnImageCreateWithNativeHandle;
if (nullptr == pfnImageCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -1525,14 +1477,6 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeSampler = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeSampler, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -1550,18 +1494,13 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Sampler.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeSampler =
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -2601,14 +2540,6 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeProgram = reinterpret_cast<ur_native_handle_t>(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bmyates : why is that we need this kind of logic in the L0 loader, of translating driver handles to L0 handles? Do you see any risks on removing this code in the UR loader?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loader handles are there to find the appropriate adapter-level function to call. The alternative is to either make all functions accept an adapter handle as one of the arguments or create a way to map normal handles to adapter handles.

This mechanism is not used when there's only one adapter, see #355.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative is to either make all functions accept an adapter handle as one of the arguments or create a way to map normal handles to adapter handles.

There's no need to. Most interop calls accept a context, which has the right adapter, and this patch exactly changes the behavior to find the adapter in the context, rather the native handle. The only notable exception is platform: #1068. That one certainly needs an adapter as an argument.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need to.

I know, I agree. I was replying to Jamie's question.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @pbalcer , @alexbatashev . Ah, this is only for native handles. That's ok then.

+1

ur_native_factory.getInstance(*phNativeProgram, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -2626,18 +2557,13 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Program.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeProgram =
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -3085,14 +3011,6 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeKernel = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeKernel, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -3112,18 +3030,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Kernel.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeKernel =
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -3297,14 +3210,6 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeQueue = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeQueue, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -3323,17 +3228,13 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeQueue)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Queue.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeQueue = reinterpret_cast<ur_native_object_t *>(hNativeQueue)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -3570,14 +3471,6 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeEvent = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeEvent, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -3595,17 +3488,13 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeEvent)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Event.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeEvent = reinterpret_cast<ur_native_object_t *>(hNativeEvent)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down