@@ -2303,14 +2303,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23032303 }
23042304}
23052305
2306- // Sets arguments for a given kernel and device based on the argument type.
2307- // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2308- // extension.
2309- static void SetArgBasedOnType (
2310- adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2306+ // Gets UR argument struct for a given kernel and device based on the argument
2307+ // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2308+ // the graphs extension (LaunchWithArgs for graphs is planned future work).
2309+ static void GetUrArgsBasedOnType (
23112310 device_image_impl *DeviceImageImpl,
23122311 const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2313- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2312+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2313+ std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
23142314 switch (Arg.MType ) {
23152315 case kernel_param_kind_t ::kind_dynamic_work_group_memory:
23162316 break ;
@@ -2330,52 +2330,61 @@ static void SetArgBasedOnType(
23302330 getMemAllocationFunc
23312331 ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
23322332 : nullptr ;
2333- ur_kernel_arg_mem_obj_properties_t MemObjData{};
2334- MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2335- MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2336- Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2337- &MemObjData, MemArg);
2333+ ur_exp_kernel_arg_value_t Value = {};
2334+ Value.memObjTuple = {MemArg, AccessModeToUr (Req->MAccessMode )};
2335+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2336+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2337+ static_cast <uint32_t >(NextTrueIndex), sizeof (MemArg),
2338+ Value});
23382339 break ;
23392340 }
23402341 case kernel_param_kind_t ::kind_std_layout: {
2342+ ur_exp_kernel_arg_type_t Type;
23412343 if (Arg.MPtr ) {
2342- Adapter.call <UrApiKind::urKernelSetArgValue>(
2343- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2344+ Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23442345 } else {
2345- Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2346- Arg.MSize , nullptr );
2346+ Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23472347 }
2348+ ur_exp_kernel_arg_value_t Value = {};
2349+ Value.value = {Arg.MPtr };
2350+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2351+ Type, static_cast <uint32_t >(NextTrueIndex),
2352+ static_cast <size_t >(Arg.MSize ), Value});
23482353
23492354 break ;
23502355 }
23512356 case kernel_param_kind_t ::kind_sampler: {
23522357 sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2353- ur_sampler_handle_t Sampler =
2354- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2355- ->getOrCreateSampler (ContextImpl);
2356- Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2357- nullptr , Sampler);
2358+ ur_exp_kernel_arg_value_t Value = {};
2359+ Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2360+ ->getOrCreateSampler (ContextImpl);
2361+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2362+ UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2363+ static_cast <uint32_t >(NextTrueIndex),
2364+ sizeof (ur_sampler_handle_t ), Value});
23582365 break ;
23592366 }
23602367 case kernel_param_kind_t ::kind_pointer: {
2361- // We need to de-rerence this to get the actual USM allocation - that's the
2368+ ur_exp_kernel_arg_value_t Value = {};
2369+ // We need to de-rerence to get the actual USM allocation - that's the
23622370 // pointer UR is expecting.
2363- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2364- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2365- nullptr , Ptr);
2371+ Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2372+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2373+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2374+ static_cast <uint32_t >(NextTrueIndex), sizeof (Arg.MPtr ),
2375+ Value});
23662376 break ;
23672377 }
23682378 case kernel_param_kind_t ::kind_specialization_constants_buffer: {
23692379 assert (DeviceImageImpl != nullptr );
23702380 ur_mem_handle_t SpecConstsBuffer =
23712381 DeviceImageImpl->get_spec_const_buffer_ref ();
2372-
2373- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2374- MemObjProps.pNext = nullptr ;
2375- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2376- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2377- Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2378- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2382+ ur_exp_kernel_arg_value_t Value = {};
2383+ Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2384+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2385+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2386+ static_cast <uint32_t >(NextTrueIndex),
2387+ sizeof (SpecConstsBuffer), Value});
23792388 break ;
23802389 }
23812390 case kernel_param_kind_t ::kind_invalid:
@@ -2404,22 +2413,33 @@ static ur_result_t SetKernelParamsAndLaunch(
24042413 DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
24052414 }
24062415
2416+ // just a performance optimization - avoid heap allocations
2417+ static thread_local std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2418+ UrArgs.reserve (Args.size ());
2419+ UrArgs.clear ();
2420+
24072421 if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures ) {
2408- auto setFunc = [&Adapter, Kernel,
2409- KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
2422+ auto setFunc = [KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24102423 size_t NextTrueIndex) {
24112424 const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
24122425 switch (ParamDesc.kind ) {
24132426 case kernel_param_kind_t ::kind_std_layout: {
24142427 int Size = ParamDesc.info ;
2415- Adapter.call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2416- Size, nullptr , ArgPtr);
2428+ ur_exp_kernel_arg_value_t Value = {};
2429+ Value.value = ArgPtr;
2430+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2431+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2432+ static_cast <uint32_t >(NextTrueIndex),
2433+ static_cast <size_t >(Size), Value});
24172434 break ;
24182435 }
24192436 case kernel_param_kind_t ::kind_pointer: {
2420- const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2421- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2422- nullptr , Ptr);
2437+ ur_exp_kernel_arg_value_t Value = {};
2438+ Value.pointer = *static_cast <const void *const *>(ArgPtr);
2439+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2440+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2441+ static_cast <uint32_t >(NextTrueIndex),
2442+ sizeof (Value.pointer ), Value});
24232443 break ;
24242444 }
24252445 default :
@@ -2429,23 +2449,28 @@ static ur_result_t SetKernelParamsAndLaunch(
24292449 applyFuncOnFilteredArgs (EliminatedArgMask, DeviceKernelInfo.NumParams ,
24302450 DeviceKernelInfo.ParamDescGetter , setFunc);
24312451 } else {
2432- auto setFunc = [&Adapter, Kernel, & DeviceImageImpl, &getMemAllocationFunc,
2452+ auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc,
24332453 &Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2434- SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2435- Queue.getContextImpl (), Arg, NextTrueIndex);
2454+ GetUrArgsBasedOnType ( DeviceImageImpl, getMemAllocationFunc,
2455+ Queue.getContextImpl (), Arg, NextTrueIndex, UrArgs );
24362456 };
24372457 applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
24382458 }
24392459
2440- const std::optional<int > &ImplicitLocalArg =
2441- DeviceKernelInfo.getImplicitLocalArgPos ();
2460+ std::optional<int > ImplicitLocalArg =
2461+ ProgramManager::getInstance ().kernelImplicitLocalArgPos (
2462+ DeviceKernelInfo.Name );
24422463 // Set the implicit local memory buffer to support
24432464 // get_work_group_scratch_memory. This is for backend not supporting
24442465 // CUDA-style local memory setting. Note that we may have -1 as a position,
24452466 // this indicates the buffer is actually unused and was elided.
24462467 if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2447- Adapter.call <UrApiKind::urKernelSetArgLocal>(
2448- Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
2468+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2469+ nullptr ,
2470+ UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2471+ static_cast <uint32_t >(ImplicitLocalArg.value ()),
2472+ WorkGroupMemorySize,
2473+ {nullptr }});
24492474 }
24502475
24512476 adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2468,16 +2493,14 @@ static ur_result_t SetKernelParamsAndLaunch(
24682493 /* pPropSizeRet = */ nullptr );
24692494
24702495 const bool EnforcedLocalSize =
2471- (RequiredWGSize[0 ] != 0 &&
2472- (NDRDesc.Dims < 2 || RequiredWGSize[1 ] != 0 ) &&
2473- (NDRDesc.Dims < 3 || RequiredWGSize[2 ] != 0 ));
2496+ (RequiredWGSize[0 ] != 0 || RequiredWGSize[1 ] != 0 ||
2497+ RequiredWGSize[2 ] != 0 );
24742498 if (EnforcedLocalSize)
24752499 LocalSize = RequiredWGSize;
24762500 }
2477-
2478- const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 &&
2479- (NDRDesc.Dims < 2 || NDRDesc.GlobalOffset [1 ] != 0 ) &&
2480- (NDRDesc.Dims < 3 || NDRDesc.GlobalOffset [2 ] != 0 );
2501+ const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 ||
2502+ NDRDesc.GlobalOffset [1 ] != 0 ||
2503+ NDRDesc.GlobalOffset [2 ] != 0 ;
24812504
24822505 std::vector<ur_kernel_launch_property_t > property_list;
24832506
@@ -2505,20 +2528,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25052528 {{WorkGroupMemorySize}}});
25062529 }
25072530 ur_event_handle_t UREvent = nullptr ;
2508- ur_result_t Error = Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2509- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2510- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2511- LocalSize, property_list.size (),
2512- property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2513- RawEvents.empty () ? nullptr : &RawEvents[0 ],
2514- OutEventImpl ? &UREvent : nullptr );
2531+ ur_result_t Error =
2532+ Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2533+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2534+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2535+ &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2536+ property_list.size (),
2537+ property_list.empty () ? nullptr : property_list.data (),
2538+ RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2539+ OutEventImpl ? &UREvent : nullptr );
25152540 if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25162541 OutEventImpl->setHandle (UREvent);
25172542 }
25182543
25192544 return Error;
25202545}
25212546
2547+ // Sets arguments for a given kernel and device based on the argument type.
2548+ // This is a legacy path which the graphs extension still uses.
2549+ static void SetArgBasedOnType (
2550+ adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2551+ device_image_impl *DeviceImageImpl,
2552+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2553+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2554+ switch (Arg.MType ) {
2555+ case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2556+ break ;
2557+ case kernel_param_kind_t ::kind_work_group_memory:
2558+ break ;
2559+ case kernel_param_kind_t ::kind_stream:
2560+ break ;
2561+ case kernel_param_kind_t ::kind_dynamic_accessor:
2562+ case kernel_param_kind_t ::kind_accessor: {
2563+ Requirement *Req = (Requirement *)(Arg.MPtr );
2564+
2565+ // getMemAllocationFunc is nullptr when there are no requirements. However,
2566+ // we may pass default constructed accessors to a command, which don't add
2567+ // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2568+ // valid case, so we need to properly handle it.
2569+ ur_mem_handle_t MemArg =
2570+ getMemAllocationFunc
2571+ ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2572+ : nullptr ;
2573+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2574+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2575+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2576+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2577+ &MemObjData, MemArg);
2578+ break ;
2579+ }
2580+ case kernel_param_kind_t ::kind_std_layout: {
2581+ if (Arg.MPtr ) {
2582+ Adapter.call <UrApiKind::urKernelSetArgValue>(
2583+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2584+ } else {
2585+ Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2586+ Arg.MSize , nullptr );
2587+ }
2588+
2589+ break ;
2590+ }
2591+ case kernel_param_kind_t ::kind_sampler: {
2592+ sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2593+ ur_sampler_handle_t Sampler =
2594+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2595+ ->getOrCreateSampler (ContextImpl);
2596+ Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2597+ nullptr , Sampler);
2598+ break ;
2599+ }
2600+ case kernel_param_kind_t ::kind_pointer: {
2601+ // We need to de-rerence this to get the actual USM allocation - that's the
2602+ // pointer UR is expecting.
2603+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2604+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2605+ nullptr , Ptr);
2606+ break ;
2607+ }
2608+ case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2609+ assert (DeviceImageImpl != nullptr );
2610+ ur_mem_handle_t SpecConstsBuffer =
2611+ DeviceImageImpl->get_spec_const_buffer_ref ();
2612+
2613+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2614+ MemObjProps.pNext = nullptr ;
2615+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2616+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2617+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2618+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2619+ break ;
2620+ }
2621+ case kernel_param_kind_t ::kind_invalid:
2622+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2623+ " Invalid kernel param kind " +
2624+ codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2625+ break ;
2626+ }
2627+ }
2628+
25222629static std::tuple<ur_kernel_handle_t , device_image_impl *,
25232630 const KernelArgMask *>
25242631getCGKernelInfo (const CGExecKernel &CommandGroup, context_impl &ContextImpl,
0 commit comments