Skip to content

Commit 4cef72d

Browse files
committed
Revert "Revert "[UR][SYCL] Introduce UR api to set kernel args + launch in one call." (intel#19661)"
This reverts commit d25d6d6. Signed-off-by: Lukasz Dorau <lukasz.dorau@intel.com>
1 parent 889db98 commit 4cef72d

File tree

84 files changed

+3372
-429
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+3372
-429
lines changed

sycl/source/detail/scheduler/commands.cpp

Lines changed: 165 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,14 +2303,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23032303
}
23042304
}
23052305

2306-
// Sets arguments for a given kernel and device based on the argument type.
2307-
// Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2308-
// extension.
2309-
static void SetArgBasedOnType(
2310-
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2306+
// Gets UR argument struct for a given kernel and device based on the argument
2307+
// type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2308+
// the graphs extension (LaunchWithArgs for graphs is planned future work).
2309+
static void GetUrArgsBasedOnType(
23112310
device_image_impl *DeviceImageImpl,
23122311
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2313-
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2312+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2313+
std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
23142314
switch (Arg.MType) {
23152315
case kernel_param_kind_t::kind_dynamic_work_group_memory:
23162316
break;
@@ -2330,52 +2330,61 @@ static void SetArgBasedOnType(
23302330
getMemAllocationFunc
23312331
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
23322332
: nullptr;
2333-
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2334-
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2335-
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2336-
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2337-
&MemObjData, MemArg);
2333+
ur_exp_kernel_arg_value_t Value = {};
2334+
Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)};
2335+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2336+
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2337+
static_cast<uint32_t>(NextTrueIndex), sizeof(MemArg),
2338+
Value});
23382339
break;
23392340
}
23402341
case kernel_param_kind_t::kind_std_layout: {
2342+
ur_exp_kernel_arg_type_t Type;
23412343
if (Arg.MPtr) {
2342-
Adapter.call<UrApiKind::urKernelSetArgValue>(
2343-
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2344+
Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23442345
} else {
2345-
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2346-
Arg.MSize, nullptr);
2346+
Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23472347
}
2348+
ur_exp_kernel_arg_value_t Value = {};
2349+
Value.value = {Arg.MPtr};
2350+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2351+
Type, static_cast<uint32_t>(NextTrueIndex),
2352+
static_cast<size_t>(Arg.MSize), Value});
23482353

23492354
break;
23502355
}
23512356
case kernel_param_kind_t::kind_sampler: {
23522357
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2353-
ur_sampler_handle_t Sampler =
2354-
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2355-
->getOrCreateSampler(ContextImpl);
2356-
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2357-
nullptr, Sampler);
2358+
ur_exp_kernel_arg_value_t Value = {};
2359+
Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2360+
->getOrCreateSampler(ContextImpl);
2361+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2362+
UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2363+
static_cast<uint32_t>(NextTrueIndex),
2364+
sizeof(ur_sampler_handle_t), Value});
23582365
break;
23592366
}
23602367
case kernel_param_kind_t::kind_pointer: {
2361-
// We need to de-rerence this to get the actual USM allocation - that's the
2368+
ur_exp_kernel_arg_value_t Value = {};
2369+
// We need to de-rerence to get the actual USM allocation - that's the
23622370
// pointer UR is expecting.
2363-
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2364-
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2365-
nullptr, Ptr);
2371+
Value.pointer = *static_cast<void *const *>(Arg.MPtr);
2372+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2373+
UR_EXP_KERNEL_ARG_TYPE_POINTER,
2374+
static_cast<uint32_t>(NextTrueIndex), sizeof(Arg.MPtr),
2375+
Value});
23662376
break;
23672377
}
23682378
case kernel_param_kind_t::kind_specialization_constants_buffer: {
23692379
assert(DeviceImageImpl != nullptr);
23702380
ur_mem_handle_t SpecConstsBuffer =
23712381
DeviceImageImpl->get_spec_const_buffer_ref();
2372-
2373-
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2374-
MemObjProps.pNext = nullptr;
2375-
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2376-
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2377-
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
2378-
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2382+
ur_exp_kernel_arg_value_t Value = {};
2383+
Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2384+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2385+
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2386+
static_cast<uint32_t>(NextTrueIndex),
2387+
sizeof(SpecConstsBuffer), Value});
23792388
break;
23802389
}
23812390
case kernel_param_kind_t::kind_invalid:
@@ -2404,22 +2413,32 @@ static ur_result_t SetKernelParamsAndLaunch(
24042413
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref() : Empty);
24052414
}
24062415

2416+
std::vector<ur_exp_kernel_arg_properties_t> UrArgs;
2417+
UrArgs.reserve(Args.size());
2418+
24072419
if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures) {
2408-
auto setFunc = [&Adapter, Kernel,
2420+
auto setFunc = [&UrArgs,
24092421
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24102422
size_t NextTrueIndex) {
24112423
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
24122424
switch (ParamDesc.kind) {
24132425
case kernel_param_kind_t::kind_std_layout: {
24142426
int Size = ParamDesc.info;
2415-
Adapter.call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2416-
Size, nullptr, ArgPtr);
2427+
ur_exp_kernel_arg_value_t Value = {};
2428+
Value.value = ArgPtr;
2429+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2430+
UR_EXP_KERNEL_ARG_TYPE_VALUE,
2431+
static_cast<uint32_t>(NextTrueIndex),
2432+
static_cast<size_t>(Size), Value});
24172433
break;
24182434
}
24192435
case kernel_param_kind_t::kind_pointer: {
2420-
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
2421-
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2422-
nullptr, Ptr);
2436+
ur_exp_kernel_arg_value_t Value = {};
2437+
Value.pointer = *static_cast<const void *const *>(ArgPtr);
2438+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
2439+
UR_EXP_KERNEL_ARG_TYPE_POINTER,
2440+
static_cast<uint32_t>(NextTrueIndex),
2441+
sizeof(Value.pointer), Value});
24232442
break;
24242443
}
24252444
default:
@@ -2429,23 +2448,28 @@ static ur_result_t SetKernelParamsAndLaunch(
24292448
applyFuncOnFilteredArgs(EliminatedArgMask, DeviceKernelInfo.NumParams,
24302449
DeviceKernelInfo.ParamDescGetter, setFunc);
24312450
} else {
2432-
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2433-
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2434-
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2435-
Queue.getContextImpl(), Arg, NextTrueIndex);
2451+
auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc, &Queue,
2452+
&UrArgs](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2453+
GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc,
2454+
Queue.getContextImpl(), Arg, NextTrueIndex, UrArgs);
24362455
};
24372456
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
24382457
}
24392458

2440-
const std::optional<int> &ImplicitLocalArg =
2441-
DeviceKernelInfo.getImplicitLocalArgPos();
2459+
std::optional<int> ImplicitLocalArg =
2460+
ProgramManager::getInstance().kernelImplicitLocalArgPos(
2461+
DeviceKernelInfo.Name);
24422462
// Set the implicit local memory buffer to support
24432463
// get_work_group_scratch_memory. This is for backend not supporting
24442464
// CUDA-style local memory setting. Note that we may have -1 as a position,
24452465
// this indicates the buffer is actually unused and was elided.
24462466
if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) {
2447-
Adapter.call<UrApiKind::urKernelSetArgLocal>(
2448-
Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr);
2467+
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2468+
nullptr,
2469+
UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2470+
static_cast<uint32_t>(ImplicitLocalArg.value()),
2471+
WorkGroupMemorySize,
2472+
{nullptr}});
24492473
}
24502474

24512475
adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl());
@@ -2468,16 +2492,14 @@ static ur_result_t SetKernelParamsAndLaunch(
24682492
/* pPropSizeRet = */ nullptr);
24692493

24702494
const bool EnforcedLocalSize =
2471-
(RequiredWGSize[0] != 0 &&
2472-
(NDRDesc.Dims < 2 || RequiredWGSize[1] != 0) &&
2473-
(NDRDesc.Dims < 3 || RequiredWGSize[2] != 0));
2495+
(RequiredWGSize[0] != 0 || RequiredWGSize[1] != 0 ||
2496+
RequiredWGSize[2] != 0);
24742497
if (EnforcedLocalSize)
24752498
LocalSize = RequiredWGSize;
24762499
}
2477-
2478-
const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 &&
2479-
(NDRDesc.Dims < 2 || NDRDesc.GlobalOffset[1] != 0) &&
2480-
(NDRDesc.Dims < 3 || NDRDesc.GlobalOffset[2] != 0);
2500+
const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 ||
2501+
NDRDesc.GlobalOffset[1] != 0 ||
2502+
NDRDesc.GlobalOffset[2] != 0;
24812503

24822504
std::vector<ur_kernel_launch_property_t> property_list;
24832505

@@ -2505,20 +2527,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25052527
{{WorkGroupMemorySize}}});
25062528
}
25072529
ur_event_handle_t UREvent = nullptr;
2508-
ur_result_t Error = Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunch>(
2509-
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2510-
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0],
2511-
LocalSize, property_list.size(),
2512-
property_list.empty() ? nullptr : property_list.data(), RawEvents.size(),
2513-
RawEvents.empty() ? nullptr : &RawEvents[0],
2514-
OutEventImpl ? &UREvent : nullptr);
2530+
ur_result_t Error =
2531+
Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2532+
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2533+
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr,
2534+
&NDRDesc.GlobalSize[0], LocalSize, UrArgs.size(), UrArgs.data(),
2535+
property_list.size(),
2536+
property_list.empty() ? nullptr : property_list.data(),
2537+
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0],
2538+
OutEventImpl ? &UREvent : nullptr);
25152539
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25162540
OutEventImpl->setHandle(UREvent);
25172541
}
25182542

25192543
return Error;
25202544
}
25212545

2546+
// Sets arguments for a given kernel and device based on the argument type.
2547+
// This is a legacy path which the graphs extension still uses.
2548+
static void SetArgBasedOnType(
2549+
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2550+
device_image_impl *DeviceImageImpl,
2551+
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2552+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2553+
switch (Arg.MType) {
2554+
case kernel_param_kind_t::kind_dynamic_work_group_memory:
2555+
break;
2556+
case kernel_param_kind_t::kind_work_group_memory:
2557+
break;
2558+
case kernel_param_kind_t::kind_stream:
2559+
break;
2560+
case kernel_param_kind_t::kind_dynamic_accessor:
2561+
case kernel_param_kind_t::kind_accessor: {
2562+
Requirement *Req = (Requirement *)(Arg.MPtr);
2563+
2564+
// getMemAllocationFunc is nullptr when there are no requirements. However,
2565+
// we may pass default constructed accessors to a command, which don't add
2566+
// requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2567+
// valid case, so we need to properly handle it.
2568+
ur_mem_handle_t MemArg =
2569+
getMemAllocationFunc
2570+
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
2571+
: nullptr;
2572+
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2573+
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2574+
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2575+
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2576+
&MemObjData, MemArg);
2577+
break;
2578+
}
2579+
case kernel_param_kind_t::kind_std_layout: {
2580+
if (Arg.MPtr) {
2581+
Adapter.call<UrApiKind::urKernelSetArgValue>(
2582+
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2583+
} else {
2584+
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2585+
Arg.MSize, nullptr);
2586+
}
2587+
2588+
break;
2589+
}
2590+
case kernel_param_kind_t::kind_sampler: {
2591+
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2592+
ur_sampler_handle_t Sampler =
2593+
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2594+
->getOrCreateSampler(ContextImpl);
2595+
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2596+
nullptr, Sampler);
2597+
break;
2598+
}
2599+
case kernel_param_kind_t::kind_pointer: {
2600+
// We need to de-rerence this to get the actual USM allocation - that's the
2601+
// pointer UR is expecting.
2602+
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2603+
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2604+
nullptr, Ptr);
2605+
break;
2606+
}
2607+
case kernel_param_kind_t::kind_specialization_constants_buffer: {
2608+
assert(DeviceImageImpl != nullptr);
2609+
ur_mem_handle_t SpecConstsBuffer =
2610+
DeviceImageImpl->get_spec_const_buffer_ref();
2611+
2612+
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2613+
MemObjProps.pNext = nullptr;
2614+
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2615+
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2616+
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
2617+
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2618+
break;
2619+
}
2620+
case kernel_param_kind_t::kind_invalid:
2621+
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
2622+
"Invalid kernel param kind " +
2623+
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
2624+
break;
2625+
}
2626+
}
2627+
25222628
static std::tuple<ur_kernel_handle_t, device_image_impl *,
25232629
const KernelArgMask *>
25242630
getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,

sycl/test-e2e/Adapters/level_zero/batch_barrier.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ int main(int argc, char *argv[]) {
2424
queue q;
2525

2626
submit_kernel(q); // starts a batch
27-
// CHECK: ---> urEnqueueKernelLaunch
27+
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
2828
// CHECK-NOT: zeCommandQueueExecuteCommandLists
2929

3030
// Initializing Level Zero driver is required if this test is linked
@@ -42,7 +42,7 @@ int main(int argc, char *argv[]) {
4242
// CHECK-NOT: zeCommandQueueExecuteCommandLists
4343

4444
submit_kernel(q);
45-
// CHECK: ---> urEnqueueKernelLaunch
45+
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
4646
// CHECK-NOT: zeCommandQueueExecuteCommandLists
4747

4848
// interop should close the batch

0 commit comments

Comments
 (0)