@@ -2303,14 +2303,22 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23032303 }
23042304}
23052305
2306- // Sets arguments for a given kernel and device based on the argument type.
2307- // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2308- // extension.
2309- static void SetArgBasedOnType (
2310- adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2306+ // Gets UR argument struct for a given kernel and device based on the argument
2307+ // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2308+ // the graphs extension (LaunchWithArgs for graphs is planned future work).
2309+ static void GetUrArgsBasedOnType (
23112310 device_image_impl *DeviceImageImpl,
23122311 const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2313- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2312+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2313+ std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
2314+ // UrArg.size == 0 indicates uninitialized structure
2315+ ur_exp_kernel_arg_properties_t UrArg = {
2316+ UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2317+ nullptr ,
2318+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2319+ static_cast <uint32_t >(NextTrueIndex),
2320+ 0 ,
2321+ {}};
23142322 switch (Arg.MType ) {
23152323 case kernel_param_kind_t ::kind_dynamic_work_group_memory:
23162324 break ;
@@ -2330,52 +2338,56 @@ static void SetArgBasedOnType(
23302338 getMemAllocationFunc
23312339 ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
23322340 : nullptr ;
2333- ur_kernel_arg_mem_obj_properties_t MemObjData {};
2334- MemObjData. stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES ;
2335- MemObjData. memoryAccess = AccessModeToUr (Req-> MAccessMode ) ;
2336- Adapter. call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2337- &MemObjData, MemArg) ;
2341+ ur_exp_kernel_arg_value_t Value = {};
2342+ Value. memObjTuple = {MemArg, AccessModeToUr (Req-> MAccessMode )} ;
2343+ UrArg. type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ ;
2344+ UrArg. size = sizeof (MemArg);
2345+ UrArg. value = Value ;
23382346 break ;
23392347 }
23402348 case kernel_param_kind_t ::kind_std_layout: {
2349+ ur_exp_kernel_arg_type_t Type;
23412350 if (Arg.MPtr ) {
2342- Adapter.call <UrApiKind::urKernelSetArgValue>(
2343- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2351+ Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23442352 } else {
2345- Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2346- Arg.MSize , nullptr );
2353+ Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23472354 }
2348-
2355+ ur_exp_kernel_arg_value_t Value = {};
2356+ Value.value = {Arg.MPtr };
2357+ UrArg.type = Type;
2358+ UrArg.size = static_cast <size_t >(Arg.MSize );
2359+ UrArg.value = Value;
23492360 break ;
23502361 }
23512362 case kernel_param_kind_t ::kind_sampler: {
23522363 sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2353- ur_sampler_handle_t Sampler =
2354- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2355- ->getOrCreateSampler (ContextImpl);
2356- Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2357- nullptr , Sampler);
2364+ ur_exp_kernel_arg_value_t Value = {};
2365+ Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2366+ ->getOrCreateSampler (ContextImpl);
2367+ UrArg.type = UR_EXP_KERNEL_ARG_TYPE_SAMPLER;
2368+ UrArg.size = sizeof (ur_sampler_handle_t );
2369+ UrArg.value = Value;
23582370 break ;
23592371 }
23602372 case kernel_param_kind_t ::kind_pointer: {
2361- // We need to de-rerence this to get the actual USM allocation - that's the
2373+ ur_exp_kernel_arg_value_t Value = {};
2374+ // We need to de-rerence to get the actual USM allocation - that's the
23622375 // pointer UR is expecting.
2363- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2364- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2365- nullptr , Ptr);
2376+ Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2377+ UrArg.type = UR_EXP_KERNEL_ARG_TYPE_POINTER;
2378+ UrArg.size = sizeof (Arg.MPtr );
2379+ UrArg.value = Value;
23662380 break ;
23672381 }
23682382 case kernel_param_kind_t ::kind_specialization_constants_buffer: {
23692383 assert (DeviceImageImpl != nullptr );
23702384 ur_mem_handle_t SpecConstsBuffer =
23712385 DeviceImageImpl->get_spec_const_buffer_ref ();
2372-
2373- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2374- MemObjProps.pNext = nullptr ;
2375- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2376- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2377- Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2378- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2386+ ur_exp_kernel_arg_value_t Value = {};
2387+ Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2388+ UrArg.type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ;
2389+ UrArg.size = sizeof (SpecConstsBuffer);
2390+ UrArg.value = Value;
23792391 break ;
23802392 }
23812393 case kernel_param_kind_t ::kind_invalid:
@@ -2384,6 +2396,10 @@ static void SetArgBasedOnType(
23842396 codeToString (UR_RESULT_ERROR_INVALID_VALUE));
23852397 break ;
23862398 }
2399+
2400+ if (UrArg.size ) {
2401+ UrArgs.push_back (UrArg);
2402+ }
23872403}
23882404
23892405static ur_result_t SetKernelParamsAndLaunch (
@@ -2404,22 +2420,33 @@ static ur_result_t SetKernelParamsAndLaunch(
24042420 DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
24052421 }
24062422
2423+ // just a performance optimization - avoid heap allocations
2424+ Queue.UrArgs .reserve (Args.size ());
2425+ Queue.UrArgs .clear ();
2426+
24072427 if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures ) {
2408- auto setFunc = [&Adapter, Kernel ,
2428+ auto setFunc = [&Queue ,
24092429 KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24102430 size_t NextTrueIndex) {
24112431 const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
24122432 switch (ParamDesc.kind ) {
24132433 case kernel_param_kind_t ::kind_std_layout: {
24142434 int Size = ParamDesc.info ;
2415- Adapter.call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2416- Size, nullptr , ArgPtr);
2435+ ur_exp_kernel_arg_value_t Value = {};
2436+ Value.value = ArgPtr;
2437+ Queue.UrArgs .push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2438+ nullptr , UR_EXP_KERNEL_ARG_TYPE_VALUE,
2439+ static_cast <uint32_t >(NextTrueIndex),
2440+ static_cast <size_t >(Size), Value});
24172441 break ;
24182442 }
24192443 case kernel_param_kind_t ::kind_pointer: {
2420- const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2421- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2422- nullptr , Ptr);
2444+ ur_exp_kernel_arg_value_t Value = {};
2445+ Value.pointer = *static_cast <const void *const *>(ArgPtr);
2446+ Queue.UrArgs .push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2447+ nullptr , UR_EXP_KERNEL_ARG_TYPE_POINTER,
2448+ static_cast <uint32_t >(NextTrueIndex),
2449+ sizeof (Value.pointer ), Value});
24232450 break ;
24242451 }
24252452 default :
@@ -2429,23 +2456,29 @@ static ur_result_t SetKernelParamsAndLaunch(
24292456 applyFuncOnFilteredArgs (EliminatedArgMask, DeviceKernelInfo.NumParams ,
24302457 DeviceKernelInfo.ParamDescGetter , setFunc);
24312458 } else {
2432- auto setFunc = [&Adapter, Kernel, & DeviceImageImpl, &getMemAllocationFunc,
2459+ auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc,
24332460 &Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2434- SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2435- Queue.getContextImpl (), Arg, NextTrueIndex);
2461+ GetUrArgsBasedOnType (DeviceImageImpl, getMemAllocationFunc,
2462+ Queue.getContextImpl (), Arg, NextTrueIndex,
2463+ Queue.UrArgs );
24362464 };
24372465 applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
24382466 }
24392467
2440- const std::optional<int > &ImplicitLocalArg =
2441- DeviceKernelInfo.getImplicitLocalArgPos ();
2468+ std::optional<int > ImplicitLocalArg =
2469+ ProgramManager::getInstance ().kernelImplicitLocalArgPos (
2470+ DeviceKernelInfo.Name );
24422471 // Set the implicit local memory buffer to support
24432472 // get_work_group_scratch_memory. This is for backend not supporting
24442473 // CUDA-style local memory setting. Note that we may have -1 as a position,
24452474 // this indicates the buffer is actually unused and was elided.
24462475 if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2447- Adapter.call <UrApiKind::urKernelSetArgLocal>(
2448- Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
2476+ Queue.UrArgs .push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2477+ nullptr ,
2478+ UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2479+ static_cast <uint32_t >(ImplicitLocalArg.value ()),
2480+ WorkGroupMemorySize,
2481+ {nullptr }});
24492482 }
24502483
24512484 adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2468,16 +2501,14 @@ static ur_result_t SetKernelParamsAndLaunch(
24682501 /* pPropSizeRet = */ nullptr );
24692502
24702503 const bool EnforcedLocalSize =
2471- (RequiredWGSize[0 ] != 0 &&
2472- (NDRDesc.Dims < 2 || RequiredWGSize[1 ] != 0 ) &&
2473- (NDRDesc.Dims < 3 || RequiredWGSize[2 ] != 0 ));
2504+ (RequiredWGSize[0 ] != 0 || RequiredWGSize[1 ] != 0 ||
2505+ RequiredWGSize[2 ] != 0 );
24742506 if (EnforcedLocalSize)
24752507 LocalSize = RequiredWGSize;
24762508 }
2477-
2478- const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 &&
2479- (NDRDesc.Dims < 2 || NDRDesc.GlobalOffset [1 ] != 0 ) &&
2480- (NDRDesc.Dims < 3 || NDRDesc.GlobalOffset [2 ] != 0 );
2509+ const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 ||
2510+ NDRDesc.GlobalOffset [1 ] != 0 ||
2511+ NDRDesc.GlobalOffset [2 ] != 0 ;
24812512
24822513 std::vector<ur_kernel_launch_property_t > property_list;
24832514
@@ -2505,20 +2536,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25052536 {{WorkGroupMemorySize}}});
25062537 }
25072538 ur_event_handle_t UREvent = nullptr ;
2508- ur_result_t Error = Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2509- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2510- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2511- LocalSize, property_list.size (),
2512- property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2513- RawEvents.empty () ? nullptr : &RawEvents[0 ],
2514- OutEventImpl ? &UREvent : nullptr );
2539+ ur_result_t Error =
2540+ Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2541+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2542+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2543+ &NDRDesc.GlobalSize [0 ], LocalSize, Queue.UrArgs .size (),
2544+ Queue.UrArgs .data (), property_list.size (),
2545+ property_list.empty () ? nullptr : property_list.data (),
2546+ RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2547+ OutEventImpl ? &UREvent : nullptr );
25152548 if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25162549 OutEventImpl->setHandle (UREvent);
25172550 }
25182551
25192552 return Error;
25202553}
25212554
2555+ // Sets arguments for a given kernel and device based on the argument type.
2556+ // This is a legacy path which the graphs extension still uses.
2557+ static void SetArgBasedOnType (
2558+ adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2559+ device_image_impl *DeviceImageImpl,
2560+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2561+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2562+ switch (Arg.MType ) {
2563+ case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2564+ break ;
2565+ case kernel_param_kind_t ::kind_work_group_memory:
2566+ break ;
2567+ case kernel_param_kind_t ::kind_stream:
2568+ break ;
2569+ case kernel_param_kind_t ::kind_dynamic_accessor:
2570+ case kernel_param_kind_t ::kind_accessor: {
2571+ Requirement *Req = (Requirement *)(Arg.MPtr );
2572+
2573+ // getMemAllocationFunc is nullptr when there are no requirements. However,
2574+ // we may pass default constructed accessors to a command, which don't add
2575+ // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2576+ // valid case, so we need to properly handle it.
2577+ ur_mem_handle_t MemArg =
2578+ getMemAllocationFunc
2579+ ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2580+ : nullptr ;
2581+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2582+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2583+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2584+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2585+ &MemObjData, MemArg);
2586+ break ;
2587+ }
2588+ case kernel_param_kind_t ::kind_std_layout: {
2589+ if (Arg.MPtr ) {
2590+ Adapter.call <UrApiKind::urKernelSetArgValue>(
2591+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2592+ } else {
2593+ Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2594+ Arg.MSize , nullptr );
2595+ }
2596+
2597+ break ;
2598+ }
2599+ case kernel_param_kind_t ::kind_sampler: {
2600+ sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2601+ ur_sampler_handle_t Sampler =
2602+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2603+ ->getOrCreateSampler (ContextImpl);
2604+ Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2605+ nullptr , Sampler);
2606+ break ;
2607+ }
2608+ case kernel_param_kind_t ::kind_pointer: {
2609+ // We need to de-rerence this to get the actual USM allocation - that's the
2610+ // pointer UR is expecting.
2611+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2612+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2613+ nullptr , Ptr);
2614+ break ;
2615+ }
2616+ case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2617+ assert (DeviceImageImpl != nullptr );
2618+ ur_mem_handle_t SpecConstsBuffer =
2619+ DeviceImageImpl->get_spec_const_buffer_ref ();
2620+
2621+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2622+ MemObjProps.pNext = nullptr ;
2623+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2624+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2625+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2626+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2627+ break ;
2628+ }
2629+ case kernel_param_kind_t ::kind_invalid:
2630+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2631+ " Invalid kernel param kind " +
2632+ codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2633+ break ;
2634+ }
2635+ }
2636+
25222637static std::tuple<ur_kernel_handle_t , device_image_impl *,
25232638 const KernelArgMask *>
25242639getCGKernelInfo (const CGExecKernel &CommandGroup, context_impl &ContextImpl,
0 commit comments