Skip to content

Commit 86451eb

Browse files
authored
[UR][L0 v2] Set pointer kernel arguments only for queue's associated device (#20179)
1 parent 33fceed commit 86451eb

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

unified-runtime/source/adapters/level_zero/v2/kernel.cpp

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ ur_single_device_kernel_t::ur_single_device_kernel_t(ur_device_handle_t hDevice,
3333
};
3434
}
3535

36+
ur_result_t ur_single_device_kernel_t::setArgValue(uint32_t argIndex,
37+
size_t argSize,
38+
const void *pArgValue) {
39+
return setArgValueOnZeKernel(hKernel.get(), argIndex, argSize, pArgValue);
40+
}
41+
42+
ur_result_t ur_single_device_kernel_t::setArgPointer(uint32_t argIndex,
43+
const void *pArgValue) {
44+
return setArgValue(argIndex, sizeof(void *), &pArgValue);
45+
}
46+
3647
ur_result_t ur_single_device_kernel_t::release() {
3748
hKernel.reset();
3849
return UR_RESULT_SUCCESS;
@@ -187,19 +198,6 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
187198
uint32_t argIndex, size_t argSize,
188199
const ur_kernel_arg_value_properties_t * /*pProperties*/,
189200
const void *pArgValue) {
190-
191-
// OpenCL: "the arg_value pointer can be NULL or point to a NULL value
192-
// in which case a NULL value will be used as the value for the argument
193-
// declared as a pointer to global or constant memory in the kernel"
194-
//
195-
// We don't know the type of the argument but it seems that the only time
196-
// SYCL RT would send a pointer to NULL in 'arg_value' is when the argument
197-
// is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL.
198-
if (argSize == sizeof(void *) && pArgValue &&
199-
*(void **)(const_cast<void *>(pArgValue)) == nullptr) {
200-
pArgValue = nullptr;
201-
}
202-
203201
if (argIndex > zeCommonProperties->numKernelArgs - 1) {
204202
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
205203
}
@@ -209,15 +207,8 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
209207
continue;
210208
}
211209

212-
auto zeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue,
213-
(singleDeviceKernel.value().hKernel.get(),
214-
argIndex, argSize, pArgValue));
215-
216-
if (zeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
217-
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
218-
} else if (zeResult != ZE_RESULT_SUCCESS) {
219-
return ze2urResult(zeResult);
220-
}
210+
UR_CALL(setArgValueOnZeKernel(singleDeviceKernel.value().hKernel.get(),
211+
argIndex, argSize, pArgValue));
221212
}
222213
return UR_RESULT_SUCCESS;
223214
}
@@ -281,7 +272,11 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
281272
const size_t *pGlobalWorkOffset, uint32_t workDim, uint32_t groupSizeX,
282273
uint32_t groupSizeY, uint32_t groupSizeZ,
283274
ze_command_list_handle_t commandList, wait_list_view &waitListView) {
284-
auto hZeKernel = getZeHandle(hDevice);
275+
auto &deviceKernelOpt = deviceKernels[deviceIndex(hDevice)];
276+
if (!deviceKernelOpt.has_value())
277+
return UR_RESULT_ERROR_INVALID_KERNEL;
278+
auto &deviceKernel = deviceKernelOpt.value();
279+
auto hZeKernel = deviceKernel.hKernel.get();
285280

286281
if (pGlobalWorkOffset != NULL) {
287282
UR_CALL(
@@ -304,10 +299,17 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
304299
zePtr = reinterpret_cast<void *>(hImage->getZeImage());
305300
}
306301
}
307-
UR_CALL(setArgPointer(pending.argIndex, nullptr, zePtr));
302+
// Set the argument only on this device's kernel.
303+
UR_CALL(deviceKernel.setArgPointer(pending.argIndex, zePtr));
308304
}
309305
pending_allocations.clear();
310306

307+
// Apply any pending raw pointer arguments (USM pointers) for this device.
308+
for (auto &pending : pending_pointer_args) {
309+
UR_CALL(deviceKernel.setArgPointer(pending.argIndex, pending.ptrArgValue));
310+
}
311+
pending_pointer_args.clear();
312+
311313
return UR_RESULT_SUCCESS;
312314
}
313315

@@ -322,6 +324,18 @@ ur_result_t ur_kernel_handle_t_::addPendingMemoryAllocation(
322324
return UR_RESULT_SUCCESS;
323325
}
324326

327+
ur_result_t
328+
ur_kernel_handle_t_::addPendingPointerArgument(uint32_t argIndex,
329+
const void *pArgValue) {
330+
if (argIndex > zeCommonProperties->numKernelArgs - 1) {
331+
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
332+
}
333+
334+
pending_pointer_args.push_back({argIndex, pArgValue});
335+
336+
return UR_RESULT_SUCCESS;
337+
}
338+
325339
std::vector<char> ur_kernel_handle_t_::getSourceAttributes() const {
326340
uint32_t size;
327341
ZE2UR_CALL_THROWS(zeKernelGetSourceAttributes,
@@ -408,14 +422,16 @@ ur_result_t urKernelSetArgPointer(
408422
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
409423
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
410424
const ur_kernel_arg_pointer_properties_t
411-
*pProperties, ///< [in][optional] argument properties
425+
* /*pProperties*/, ///< [in][optional] argument properties
412426
const void
413427
*pArgValue ///< [in] argument value represented as matching arg type.
414428
) try {
415429
TRACK_SCOPE_LATENCY("urKernelSetArgPointer");
416430

417431
std::scoped_lock<ur_shared_mutex> guard(hKernel->Mutex);
418-
return hKernel->setArgPointer(argIndex, pProperties, pArgValue);
432+
// Store the raw pointer value and defer setting the
433+
// argument until we know the device where kernel is being submitted.
434+
return hKernel->addPendingPointerArgument(argIndex, pArgValue);
419435
} catch (...) {
420436
return exceptionToResult(std::current_exception());
421437
}

unified-runtime/source/adapters/level_zero/v2/kernel.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ struct ur_single_device_kernel_t {
2121
ze_kernel_handle_t hKernel, bool ownZeHandle);
2222
ur_result_t release();
2323

24+
// Set argument value on this device's ze kernel only.
25+
ur_result_t setArgValue(uint32_t argIndex, size_t argSize,
26+
const void *pArgValue);
27+
28+
// Convenience for pointer args: sets a pointer-sized argument on this
29+
// device's ze kernel.
30+
ur_result_t setArgPointer(uint32_t argIndex, const void *pArgValue);
31+
2432
ur_device_handle_t hDevice;
2533
v2::raii::ze_kernel_handle_t hKernel;
2634
mutable ZeCache<ZeStruct<ze_kernel_properties_t>> zeKernelProperties;
@@ -82,6 +90,10 @@ struct ur_kernel_handle_t_ : ur_object {
8290
ur_result_t
8391
addPendingMemoryAllocation(pending_memory_allocation_t allocation);
8492

93+
// Add a pending pointer argument for which device is not yet known.
94+
ur_result_t addPendingPointerArgument(uint32_t argIndex,
95+
const void *pArgValue);
96+
8597
// Set all required values for the kernel before submission (including pending
8698
// memory allocations).
8799
ur_result_t prepareForSubmission(ur_context_handle_t hContext,
@@ -92,6 +104,9 @@ struct ur_kernel_handle_t_ : ur_object {
92104
ze_command_list_handle_t cmdList,
93105
wait_list_view &waitListView);
94106

107+
// Get context of the kernel.
108+
ur_context_handle_t getContext() const { return hProgram->Context; }
109+
95110
ur::RefCount RefCount;
96111

97112
private:
@@ -115,6 +130,14 @@ struct ur_kernel_handle_t_ : ur_object {
115130

116131
std::vector<pending_memory_allocation_t> pending_allocations;
117132

133+
struct pending_pointer_arg_t {
134+
uint32_t argIndex;
135+
const void *ptrArgValue;
136+
};
137+
138+
// Pointer arguments that need to be applied per-device at submission time.
139+
std::vector<pending_pointer_arg_t> pending_pointer_args;
140+
118141
void completeInitialization();
119142

120143
// pointer to any non-null kernel in deviceKernels

0 commit comments

Comments
 (0)