Skip to content

Commit 84879b6

Browse files
committed
Revert "Revert "[UR][SYCL] Introduce UR api to set kernel args + launch in one call." (intel#19661)"
This reverts commit d25d6d6. Signed-off-by: Lukasz Dorau <lukasz.dorau@intel.com>
1 parent df94cb3 commit 84879b6

File tree

85 files changed

+3382
-433
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+3382
-433
lines changed

sycl/source/detail/queue_impl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
716716
}
717717
#endif
718718

719+
std::vector<ur_exp_kernel_arg_properties_t> UrArgs;
720+
719721
protected:
720722
template <typename HandlerType = handler>
721723
EventImplPtr insertHelperBarrier(const HandlerType &Handler) {

sycl/source/detail/scheduler/commands.cpp

Lines changed: 174 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,14 +2303,22 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23032303
}
23042304
}
23052305

2306-
// Sets arguments for a given kernel and device based on the argument type.
2307-
// Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2308-
// extension.
2309-
static void SetArgBasedOnType(
2310-
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2306+
// Gets UR argument struct for a given kernel and device based on the argument
2307+
// type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2308+
// the graphs extension (LaunchWithArgs for graphs is planned future work).
2309+
static void GetUrArgsBasedOnType(
23112310
device_image_impl *DeviceImageImpl,
23122311
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2313-
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2312+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2313+
std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
2314+
// UrArg.size == 0 indicates uninitialized structure
2315+
ur_exp_kernel_arg_properties_t UrArg = {
2316+
UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2317+
nullptr,
2318+
UR_EXP_KERNEL_ARG_TYPE_VALUE,
2319+
static_cast<uint32_t>(NextTrueIndex),
2320+
0,
2321+
{}};
23142322
switch (Arg.MType) {
23152323
case kernel_param_kind_t::kind_dynamic_work_group_memory:
23162324
break;
@@ -2330,52 +2338,56 @@ static void SetArgBasedOnType(
23302338
getMemAllocationFunc
23312339
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
23322340
: nullptr;
2333-
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2334-
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2335-
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2336-
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2337-
&MemObjData, MemArg);
2341+
ur_exp_kernel_arg_value_t Value = {};
2342+
Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)};
2343+
UrArg.type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ;
2344+
UrArg.size = sizeof(MemArg);
2345+
UrArg.value = Value;
23382346
break;
23392347
}
23402348
case kernel_param_kind_t::kind_std_layout: {
2349+
ur_exp_kernel_arg_type_t Type;
23412350
if (Arg.MPtr) {
2342-
Adapter.call<UrApiKind::urKernelSetArgValue>(
2343-
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2351+
Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23442352
} else {
2345-
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2346-
Arg.MSize, nullptr);
2353+
Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23472354
}
2348-
2355+
ur_exp_kernel_arg_value_t Value = {};
2356+
Value.value = {Arg.MPtr};
2357+
UrArg.type = Type;
2358+
UrArg.size = static_cast<size_t>(Arg.MSize);
2359+
UrArg.value = Value;
23492360
break;
23502361
}
23512362
case kernel_param_kind_t::kind_sampler: {
23522363
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2353-
ur_sampler_handle_t Sampler =
2354-
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2355-
->getOrCreateSampler(ContextImpl);
2356-
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2357-
nullptr, Sampler);
2364+
ur_exp_kernel_arg_value_t Value = {};
2365+
Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2366+
->getOrCreateSampler(ContextImpl);
2367+
UrArg.type = UR_EXP_KERNEL_ARG_TYPE_SAMPLER;
2368+
UrArg.size = sizeof(ur_sampler_handle_t);
2369+
UrArg.value = Value;
23582370
break;
23592371
}
23602372
case kernel_param_kind_t::kind_pointer: {
2361-
// We need to de-rerence this to get the actual USM allocation - that's the
2373+
ur_exp_kernel_arg_value_t Value = {};
2374+
// We need to dereference to get the actual USM allocation - that's the
23622375
// pointer UR is expecting.
2363-
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2364-
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2365-
nullptr, Ptr);
2376+
Value.pointer = *static_cast<void *const *>(Arg.MPtr);
2377+
UrArg.type = UR_EXP_KERNEL_ARG_TYPE_POINTER;
2378+
UrArg.size = sizeof(Arg.MPtr);
2379+
UrArg.value = Value;
23662380
break;
23672381
}
23682382
case kernel_param_kind_t::kind_specialization_constants_buffer: {
23692383
assert(DeviceImageImpl != nullptr);
23702384
ur_mem_handle_t SpecConstsBuffer =
23712385
DeviceImageImpl->get_spec_const_buffer_ref();
2372-
2373-
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2374-
MemObjProps.pNext = nullptr;
2375-
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2376-
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2377-
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
2378-
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2386+
ur_exp_kernel_arg_value_t Value = {};
2387+
Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2388+
UrArg.type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ;
2389+
UrArg.size = sizeof(SpecConstsBuffer);
2390+
UrArg.value = Value;
23792391
break;
23802392
}
23812393
case kernel_param_kind_t::kind_invalid:
@@ -2384,6 +2396,10 @@ static void SetArgBasedOnType(
23842396
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
23852397
break;
23862398
}
2399+
2400+
if (UrArg.size) {
2401+
UrArgs.push_back(UrArg);
2402+
}
23872403
}
23882404

23892405
static ur_result_t SetKernelParamsAndLaunch(
@@ -2404,22 +2420,33 @@ static ur_result_t SetKernelParamsAndLaunch(
24042420
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref() : Empty);
24052421
}
24062422

2423+
// just a performance optimization - avoid heap allocations
2424+
Queue.UrArgs.reserve(Args.size());
2425+
Queue.UrArgs.clear();
2426+
24072427
if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures) {
2408-
auto setFunc = [&Adapter, Kernel,
2428+
auto setFunc = [&Queue,
24092429
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24102430
size_t NextTrueIndex) {
24112431
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
24122432
switch (ParamDesc.kind) {
24132433
case kernel_param_kind_t::kind_std_layout: {
24142434
int Size = ParamDesc.info;
2415-
Adapter.call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2416-
Size, nullptr, ArgPtr);
2435+
ur_exp_kernel_arg_value_t Value = {};
2436+
Value.value = ArgPtr;
2437+
Queue.UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2438+
nullptr, UR_EXP_KERNEL_ARG_TYPE_VALUE,
2439+
static_cast<uint32_t>(NextTrueIndex),
2440+
static_cast<size_t>(Size), Value});
24172441
break;
24182442
}
24192443
case kernel_param_kind_t::kind_pointer: {
2420-
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
2421-
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2422-
nullptr, Ptr);
2444+
ur_exp_kernel_arg_value_t Value = {};
2445+
Value.pointer = *static_cast<const void *const *>(ArgPtr);
2446+
Queue.UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2447+
nullptr, UR_EXP_KERNEL_ARG_TYPE_POINTER,
2448+
static_cast<uint32_t>(NextTrueIndex),
2449+
sizeof(Value.pointer), Value});
24232450
break;
24242451
}
24252452
default:
@@ -2429,23 +2456,29 @@ static ur_result_t SetKernelParamsAndLaunch(
24292456
applyFuncOnFilteredArgs(EliminatedArgMask, DeviceKernelInfo.NumParams,
24302457
DeviceKernelInfo.ParamDescGetter, setFunc);
24312458
} else {
2432-
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2459+
auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc,
24332460
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2434-
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2435-
Queue.getContextImpl(), Arg, NextTrueIndex);
2461+
GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc,
2462+
Queue.getContextImpl(), Arg, NextTrueIndex,
2463+
Queue.UrArgs);
24362464
};
24372465
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
24382466
}
24392467

2440-
const std::optional<int> &ImplicitLocalArg =
2441-
DeviceKernelInfo.getImplicitLocalArgPos();
2468+
std::optional<int> ImplicitLocalArg =
2469+
ProgramManager::getInstance().kernelImplicitLocalArgPos(
2470+
DeviceKernelInfo.Name);
24422471
// Set the implicit local memory buffer to support
24432472
// get_work_group_scratch_memory. This is for backend not supporting
24442473
// CUDA-style local memory setting. Note that we may have -1 as a position,
24452474
// this indicates the buffer is actually unused and was elided.
24462475
if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) {
2447-
Adapter.call<UrApiKind::urKernelSetArgLocal>(
2448-
Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr);
2476+
Queue.UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2477+
nullptr,
2478+
UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2479+
static_cast<uint32_t>(ImplicitLocalArg.value()),
2480+
WorkGroupMemorySize,
2481+
{nullptr}});
24492482
}
24502483

24512484
adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl());
@@ -2468,16 +2501,14 @@ static ur_result_t SetKernelParamsAndLaunch(
24682501
/* pPropSizeRet = */ nullptr);
24692502

24702503
const bool EnforcedLocalSize =
2471-
(RequiredWGSize[0] != 0 &&
2472-
(NDRDesc.Dims < 2 || RequiredWGSize[1] != 0) &&
2473-
(NDRDesc.Dims < 3 || RequiredWGSize[2] != 0));
2504+
(RequiredWGSize[0] != 0 || RequiredWGSize[1] != 0 ||
2505+
RequiredWGSize[2] != 0);
24742506
if (EnforcedLocalSize)
24752507
LocalSize = RequiredWGSize;
24762508
}
2477-
2478-
const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 &&
2479-
(NDRDesc.Dims < 2 || NDRDesc.GlobalOffset[1] != 0) &&
2480-
(NDRDesc.Dims < 3 || NDRDesc.GlobalOffset[2] != 0);
2509+
const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 ||
2510+
NDRDesc.GlobalOffset[1] != 0 ||
2511+
NDRDesc.GlobalOffset[2] != 0;
24812512

24822513
std::vector<ur_kernel_launch_property_t> property_list;
24832514

@@ -2505,20 +2536,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25052536
{{WorkGroupMemorySize}}});
25062537
}
25072538
ur_event_handle_t UREvent = nullptr;
2508-
ur_result_t Error = Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunch>(
2509-
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2510-
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0],
2511-
LocalSize, property_list.size(),
2512-
property_list.empty() ? nullptr : property_list.data(), RawEvents.size(),
2513-
RawEvents.empty() ? nullptr : &RawEvents[0],
2514-
OutEventImpl ? &UREvent : nullptr);
2539+
ur_result_t Error =
2540+
Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2541+
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
2542+
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr,
2543+
&NDRDesc.GlobalSize[0], LocalSize, Queue.UrArgs.size(),
2544+
Queue.UrArgs.data(), property_list.size(),
2545+
property_list.empty() ? nullptr : property_list.data(),
2546+
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0],
2547+
OutEventImpl ? &UREvent : nullptr);
25152548
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25162549
OutEventImpl->setHandle(UREvent);
25172550
}
25182551

25192552
return Error;
25202553
}
25212554

2555+
// Sets arguments for a given kernel and device based on the argument type.
2556+
// This is a legacy path which the graphs extension still uses.
2557+
static void SetArgBasedOnType(
2558+
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2559+
device_image_impl *DeviceImageImpl,
2560+
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2561+
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2562+
switch (Arg.MType) {
2563+
case kernel_param_kind_t::kind_dynamic_work_group_memory:
2564+
break;
2565+
case kernel_param_kind_t::kind_work_group_memory:
2566+
break;
2567+
case kernel_param_kind_t::kind_stream:
2568+
break;
2569+
case kernel_param_kind_t::kind_dynamic_accessor:
2570+
case kernel_param_kind_t::kind_accessor: {
2571+
Requirement *Req = (Requirement *)(Arg.MPtr);
2572+
2573+
// getMemAllocationFunc is nullptr when there are no requirements. However,
2574+
// we may pass default constructed accessors to a command, which don't add
2575+
// requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2576+
// valid case, so we need to properly handle it.
2577+
ur_mem_handle_t MemArg =
2578+
getMemAllocationFunc
2579+
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
2580+
: nullptr;
2581+
ur_kernel_arg_mem_obj_properties_t MemObjData{};
2582+
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2583+
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
2584+
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2585+
&MemObjData, MemArg);
2586+
break;
2587+
}
2588+
case kernel_param_kind_t::kind_std_layout: {
2589+
if (Arg.MPtr) {
2590+
Adapter.call<UrApiKind::urKernelSetArgValue>(
2591+
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
2592+
} else {
2593+
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2594+
Arg.MSize, nullptr);
2595+
}
2596+
2597+
break;
2598+
}
2599+
case kernel_param_kind_t::kind_sampler: {
2600+
sampler *SamplerPtr = (sampler *)Arg.MPtr;
2601+
ur_sampler_handle_t Sampler =
2602+
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
2603+
->getOrCreateSampler(ContextImpl);
2604+
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2605+
nullptr, Sampler);
2606+
break;
2607+
}
2608+
case kernel_param_kind_t::kind_pointer: {
2609+
// We need to dereference this to get the actual USM allocation - that's the
2610+
// pointer UR is expecting.
2611+
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
2612+
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2613+
nullptr, Ptr);
2614+
break;
2615+
}
2616+
case kernel_param_kind_t::kind_specialization_constants_buffer: {
2617+
assert(DeviceImageImpl != nullptr);
2618+
ur_mem_handle_t SpecConstsBuffer =
2619+
DeviceImageImpl->get_spec_const_buffer_ref();
2620+
2621+
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2622+
MemObjProps.pNext = nullptr;
2623+
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2624+
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2625+
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
2626+
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2627+
break;
2628+
}
2629+
case kernel_param_kind_t::kind_invalid:
2630+
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
2631+
"Invalid kernel param kind " +
2632+
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
2633+
break;
2634+
}
2635+
}
2636+
25222637
static std::tuple<ur_kernel_handle_t, device_image_impl *,
25232638
const KernelArgMask *>
25242639
getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,

sycl/test-e2e/Adapters/level_zero/batch_barrier.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ int main(int argc, char *argv[]) {
2424
queue q;
2525

2626
submit_kernel(q); // starts a batch
27-
// CHECK: ---> urEnqueueKernelLaunch
27+
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
2828
// CHECK-NOT: zeCommandQueueExecuteCommandLists
2929

3030
// Initializing Level Zero driver is required if this test is linked
@@ -42,7 +42,7 @@ int main(int argc, char *argv[]) {
4242
// CHECK-NOT: zeCommandQueueExecuteCommandLists
4343

4444
submit_kernel(q);
45-
// CHECK: ---> urEnqueueKernelLaunch
45+
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
4646
// CHECK-NOT: zeCommandQueueExecuteCommandLists
4747

4848
// interop should close the batch

0 commit comments

Comments
 (0)