@@ -2315,14 +2315,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23152315 }
23162316}
23172317
2318- // Sets arguments for a given kernel and device based on the argument type.
2319- // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2320- // extension.
2321- static void SetArgBasedOnType (
2322- adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2318+ // Gets UR argument struct for a given kernel and device based on the argument
2319+ // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2320+ // the graphs extension (LaunchWithArgs for graphs is planned future work).
2321+ static void GetUrArgsBasedOnType (
23232322 device_image_impl *DeviceImageImpl,
23242323 const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2325- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2324+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2325+ std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
23262326 switch (Arg.MType ) {
23272327 case kernel_param_kind_t ::kind_dynamic_work_group_memory:
23282328 break ;
@@ -2342,52 +2342,61 @@ static void SetArgBasedOnType(
23422342 getMemAllocationFunc
23432343 ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
23442344 : nullptr ;
2345- ur_kernel_arg_mem_obj_properties_t MemObjData{};
2346- MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2347- MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2348- Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2349- &MemObjData, MemArg);
2345+ ur_exp_kernel_arg_value_t Value = {};
2346+ Value.memObjTuple = {MemArg, AccessModeToUr (Req->MAccessMode )};
2347+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2348+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2349+ static_cast <uint32_t >(NextTrueIndex), sizeof (MemArg),
2350+ Value});
23502351 break ;
23512352 }
23522353 case kernel_param_kind_t ::kind_std_layout: {
2354+ ur_exp_kernel_arg_type_t Type;
23532355 if (Arg.MPtr ) {
2354- Adapter.call <UrApiKind::urKernelSetArgValue>(
2355- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2356+ Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23562357 } else {
2357- Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2358- Arg.MSize , nullptr );
2358+ Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23592359 }
2360+ ur_exp_kernel_arg_value_t Value = {};
2361+ Value.value = {Arg.MPtr };
2362+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2363+ Type, static_cast <uint32_t >(NextTrueIndex),
2364+ static_cast <size_t >(Arg.MSize ), Value});
23602365
23612366 break ;
23622367 }
23632368 case kernel_param_kind_t ::kind_sampler: {
23642369 sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2365- ur_sampler_handle_t Sampler =
2366- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2367- ->getOrCreateSampler (ContextImpl);
2368- Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2369- nullptr , Sampler);
2370+ ur_exp_kernel_arg_value_t Value = {};
2371+ Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2372+ ->getOrCreateSampler (ContextImpl);
2373+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2374+ UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2375+ static_cast <uint32_t >(NextTrueIndex),
2376+ sizeof (ur_sampler_handle_t ), Value});
23702377 break ;
23712378 }
23722379 case kernel_param_kind_t ::kind_pointer: {
2373- // We need to de-rerence this to get the actual USM allocation - that's the
2380+ ur_exp_kernel_arg_value_t Value = {};
2381+ // We need to de-rerence to get the actual USM allocation - that's the
23742382 // pointer UR is expecting.
2375- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2376- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2377- nullptr , Ptr);
2383+ Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2384+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2385+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2386+ static_cast <uint32_t >(NextTrueIndex), sizeof (Arg.MPtr ),
2387+ Value});
23782388 break ;
23792389 }
23802390 case kernel_param_kind_t ::kind_specialization_constants_buffer: {
23812391 assert (DeviceImageImpl != nullptr );
23822392 ur_mem_handle_t SpecConstsBuffer =
23832393 DeviceImageImpl->get_spec_const_buffer_ref ();
2384-
2385- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2386- MemObjProps.pNext = nullptr ;
2387- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2388- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2389- Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2390- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2394+ ur_exp_kernel_arg_value_t Value = {};
2395+ Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2396+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2397+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2398+ static_cast <uint32_t >(NextTrueIndex),
2399+ sizeof (SpecConstsBuffer), Value});
23912400 break ;
23922401 }
23932402 case kernel_param_kind_t ::kind_invalid:
@@ -2420,22 +2429,32 @@ static ur_result_t SetKernelParamsAndLaunch(
24202429 DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
24212430 }
24222431
2432+ std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2433+ UrArgs.reserve (Args.size ());
2434+
24232435 if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2424- auto setFunc = [&Adapter, Kernel ,
2436+ auto setFunc = [&UrArgs ,
24252437 KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24262438 size_t NextTrueIndex) {
24272439 const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
24282440 switch (ParamDesc.kind ) {
24292441 case kernel_param_kind_t ::kind_std_layout: {
24302442 int Size = ParamDesc.info ;
2431- Adapter.call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2432- Size, nullptr , ArgPtr);
2443+ ur_exp_kernel_arg_value_t Value = {};
2444+ Value.value = ArgPtr;
2445+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2446+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2447+ static_cast <uint32_t >(NextTrueIndex),
2448+ static_cast <size_t >(Size), Value});
24332449 break ;
24342450 }
24352451 case kernel_param_kind_t ::kind_pointer: {
2436- const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2437- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2438- nullptr , Ptr);
2452+ ur_exp_kernel_arg_value_t Value = {};
2453+ Value.pointer = *static_cast <const void *const *>(ArgPtr);
2454+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2455+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2456+ static_cast <uint32_t >(NextTrueIndex),
2457+ sizeof (Value.pointer ), Value});
24392458 break ;
24402459 }
24412460 default :
@@ -2445,10 +2464,10 @@ static ur_result_t SetKernelParamsAndLaunch(
24452464 applyFuncOnFilteredArgs (EliminatedArgMask, KernelNumArgs,
24462465 KernelParamDescGetter, setFunc);
24472466 } else {
2448- auto setFunc = [&Adapter, Kernel, &DeviceImageImpl , &getMemAllocationFunc ,
2449- &Queue ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2450- SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2451- Queue.getContextImpl (), Arg, NextTrueIndex);
2467+ auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc , &Queue ,
2468+ &UrArgs ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2469+ GetUrArgsBasedOnType ( DeviceImageImpl, getMemAllocationFunc,
2470+ Queue.getContextImpl (), Arg, NextTrueIndex, UrArgs );
24522471 };
24532472 applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
24542473 }
@@ -2461,8 +2480,12 @@ static ur_result_t SetKernelParamsAndLaunch(
24612480 // CUDA-style local memory setting. Note that we may have -1 as a position,
24622481 // this indicates the buffer is actually unused and was elided.
24632482 if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2464- Adapter.call <UrApiKind::urKernelSetArgLocal>(
2465- Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
2483+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2484+ nullptr ,
2485+ UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2486+ static_cast <uint32_t >(ImplicitLocalArg.value ()),
2487+ WorkGroupMemorySize,
2488+ {nullptr }});
24662489 }
24672490
24682491 adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2520,20 +2543,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25202543 {{WorkGroupMemorySize}}});
25212544 }
25222545 ur_event_handle_t UREvent = nullptr ;
2523- ur_result_t Error = Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2524- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2525- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2526- LocalSize, property_list.size (),
2527- property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2528- RawEvents.empty () ? nullptr : &RawEvents[0 ],
2529- OutEventImpl ? &UREvent : nullptr );
2546+ ur_result_t Error =
2547+ Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2548+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2549+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2550+ &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2551+ property_list.size (),
2552+ property_list.empty () ? nullptr : property_list.data (),
2553+ RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2554+ OutEventImpl ? &UREvent : nullptr );
25302555 if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25312556 OutEventImpl->setHandle (UREvent);
25322557 }
25332558
25342559 return Error;
25352560}
25362561
2562+ // Sets arguments for a given kernel and device based on the argument type.
2563+ // This is a legacy path which the graphs extension still uses.
2564+ static void SetArgBasedOnType (
2565+ adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2566+ device_image_impl *DeviceImageImpl,
2567+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2568+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2569+ switch (Arg.MType ) {
2570+ case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2571+ break ;
2572+ case kernel_param_kind_t ::kind_work_group_memory:
2573+ break ;
2574+ case kernel_param_kind_t ::kind_stream:
2575+ break ;
2576+ case kernel_param_kind_t ::kind_dynamic_accessor:
2577+ case kernel_param_kind_t ::kind_accessor: {
2578+ Requirement *Req = (Requirement *)(Arg.MPtr );
2579+
2580+ // getMemAllocationFunc is nullptr when there are no requirements. However,
2581+ // we may pass default constructed accessors to a command, which don't add
2582+ // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2583+ // valid case, so we need to properly handle it.
2584+ ur_mem_handle_t MemArg =
2585+ getMemAllocationFunc
2586+ ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2587+ : nullptr ;
2588+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2589+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2590+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2591+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2592+ &MemObjData, MemArg);
2593+ break ;
2594+ }
2595+ case kernel_param_kind_t ::kind_std_layout: {
2596+ if (Arg.MPtr ) {
2597+ Adapter.call <UrApiKind::urKernelSetArgValue>(
2598+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2599+ } else {
2600+ Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2601+ Arg.MSize , nullptr );
2602+ }
2603+
2604+ break ;
2605+ }
2606+ case kernel_param_kind_t ::kind_sampler: {
2607+ sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2608+ ur_sampler_handle_t Sampler =
2609+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2610+ ->getOrCreateSampler (ContextImpl);
2611+ Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2612+ nullptr , Sampler);
2613+ break ;
2614+ }
2615+ case kernel_param_kind_t ::kind_pointer: {
2616+ // We need to de-rerence this to get the actual USM allocation - that's the
2617+ // pointer UR is expecting.
2618+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2619+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2620+ nullptr , Ptr);
2621+ break ;
2622+ }
2623+ case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2624+ assert (DeviceImageImpl != nullptr );
2625+ ur_mem_handle_t SpecConstsBuffer =
2626+ DeviceImageImpl->get_spec_const_buffer_ref ();
2627+
2628+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2629+ MemObjProps.pNext = nullptr ;
2630+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2631+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2632+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2633+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2634+ break ;
2635+ }
2636+ case kernel_param_kind_t ::kind_invalid:
2637+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2638+ " Invalid kernel param kind " +
2639+ codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2640+ break ;
2641+ }
2642+ }
2643+
25372644static std::tuple<ur_kernel_handle_t , device_image_impl *,
25382645 const KernelArgMask *>
25392646getCGKernelInfo (const CGExecKernel &CommandGroup, context_impl &ContextImpl,
0 commit comments