@@ -33,6 +33,17 @@ ur_single_device_kernel_t::ur_single_device_kernel_t(ur_device_handle_t hDevice,
33
33
};
34
34
}
35
35
36
+ ur_result_t ur_single_device_kernel_t::setArgValue (uint32_t argIndex,
37
+ size_t argSize,
38
+ const void *pArgValue) {
39
+ return setArgValueOnZeKernel (hKernel.get (), argIndex, argSize, pArgValue);
40
+ }
41
+
42
+ ur_result_t ur_single_device_kernel_t::setArgPointer (uint32_t argIndex,
43
+ const void *pArgValue) {
44
+ return setArgValue (argIndex, sizeof (void *), &pArgValue);
45
+ }
46
+
36
47
ur_result_t ur_single_device_kernel_t::release () {
37
48
hKernel.reset ();
38
49
return UR_RESULT_SUCCESS;
@@ -187,19 +198,6 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
187
198
uint32_t argIndex, size_t argSize,
188
199
const ur_kernel_arg_value_properties_t * /* pProperties*/ ,
189
200
const void *pArgValue) {
190
-
191
- // OpenCL: "the arg_value pointer can be NULL or point to a NULL value
192
- // in which case a NULL value will be used as the value for the argument
193
- // declared as a pointer to global or constant memory in the kernel"
194
- //
195
- // We don't know the type of the argument but it seems that the only time
196
- // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument
197
- // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL.
198
- if (argSize == sizeof (void *) && pArgValue &&
199
- *(void **)(const_cast <void *>(pArgValue)) == nullptr ) {
200
- pArgValue = nullptr ;
201
- }
202
-
203
201
if (argIndex > zeCommonProperties->numKernelArgs - 1 ) {
204
202
return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
205
203
}
@@ -209,15 +207,8 @@ ur_result_t ur_kernel_handle_t_::setArgValue(
209
207
continue ;
210
208
}
211
209
212
- auto zeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
213
- (singleDeviceKernel.value ().hKernel .get (),
214
- argIndex, argSize, pArgValue));
215
-
216
- if (zeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
217
- return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
218
- } else if (zeResult != ZE_RESULT_SUCCESS) {
219
- return ze2urResult (zeResult);
220
- }
210
+ UR_CALL (setArgValueOnZeKernel (singleDeviceKernel.value ().hKernel .get (),
211
+ argIndex, argSize, pArgValue));
221
212
}
222
213
return UR_RESULT_SUCCESS;
223
214
}
@@ -281,7 +272,11 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
281
272
const size_t *pGlobalWorkOffset, uint32_t workDim, uint32_t groupSizeX,
282
273
uint32_t groupSizeY, uint32_t groupSizeZ,
283
274
ze_command_list_handle_t commandList, wait_list_view &waitListView) {
284
- auto hZeKernel = getZeHandle (hDevice);
275
+ auto &deviceKernelOpt = deviceKernels[deviceIndex (hDevice)];
276
+ if (!deviceKernelOpt.has_value ())
277
+ return UR_RESULT_ERROR_INVALID_KERNEL;
278
+ auto &deviceKernel = deviceKernelOpt.value ();
279
+ auto hZeKernel = deviceKernel.hKernel .get ();
285
280
286
281
if (pGlobalWorkOffset != NULL ) {
287
282
UR_CALL (
@@ -304,10 +299,17 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
304
299
zePtr = reinterpret_cast <void *>(hImage->getZeImage ());
305
300
}
306
301
}
307
- UR_CALL (setArgPointer (pending.argIndex , nullptr , zePtr));
302
+ // Set the argument only on this device's kernel.
303
+ UR_CALL (deviceKernel.setArgPointer (pending.argIndex , zePtr));
308
304
}
309
305
pending_allocations.clear ();
310
306
307
+ // Apply any pending raw pointer arguments (USM pointers) for this device.
308
+ for (auto &pending : pending_pointer_args) {
309
+ UR_CALL (deviceKernel.setArgPointer (pending.argIndex , pending.ptrArgValue ));
310
+ }
311
+ pending_pointer_args.clear ();
312
+
311
313
return UR_RESULT_SUCCESS;
312
314
}
313
315
@@ -322,6 +324,18 @@ ur_result_t ur_kernel_handle_t_::addPendingMemoryAllocation(
322
324
return UR_RESULT_SUCCESS;
323
325
}
324
326
327
+ ur_result_t
328
+ ur_kernel_handle_t_::addPendingPointerArgument (uint32_t argIndex,
329
+ const void *pArgValue) {
330
+ if (argIndex > zeCommonProperties->numKernelArgs - 1 ) {
331
+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
332
+ }
333
+
334
+ pending_pointer_args.push_back ({argIndex, pArgValue});
335
+
336
+ return UR_RESULT_SUCCESS;
337
+ }
338
+
325
339
std::vector<char > ur_kernel_handle_t_::getSourceAttributes () const {
326
340
uint32_t size;
327
341
ZE2UR_CALL_THROWS (zeKernelGetSourceAttributes,
@@ -408,14 +422,16 @@ ur_result_t urKernelSetArgPointer(
408
422
ur_kernel_handle_t hKernel, // /< [in] handle of the kernel object
409
423
uint32_t argIndex, // /< [in] argument index in range [0, num args - 1]
410
424
const ur_kernel_arg_pointer_properties_t
411
- *pProperties, // /< [in][optional] argument properties
425
+ * /* pProperties*/ , // /< [in][optional] argument properties
412
426
const void
413
427
*pArgValue // /< [in] argument value represented as matching arg type.
414
428
) try {
415
429
TRACK_SCOPE_LATENCY (" urKernelSetArgPointer" );
416
430
417
431
std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
418
- return hKernel->setArgPointer (argIndex, pProperties, pArgValue);
432
+ // Store the raw pointer value and defer setting the
433
+ // argument until we know the device where kernel is being submitted.
434
+ return hKernel->addPendingPointerArgument (argIndex, pArgValue);
419
435
} catch (...) {
420
436
return exceptionToResult (std::current_exception ());
421
437
}
0 commit comments