@@ -2303,14 +2303,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23032303 }
23042304}
23052305
2306- // Sets arguments for a given kernel and device based on the argument type.
2307- // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2308- // extension.
2309- static void SetArgBasedOnType (
2310- adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2306+ // Gets UR argument struct for a given kernel and device based on the argument
2307+ // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2308+ // the graphs extension (LaunchWithArgs for graphs is planned future work).
2309+ static void GetUrArgsBasedOnType (
23112310 device_image_impl *DeviceImageImpl,
23122311 const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2313- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2312+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2313+ std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
23142314 switch (Arg.MType ) {
23152315 case kernel_param_kind_t ::kind_dynamic_work_group_memory:
23162316 break ;
@@ -2330,52 +2330,61 @@ static void SetArgBasedOnType(
23302330 getMemAllocationFunc
23312331 ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
23322332 : nullptr ;
2333- ur_kernel_arg_mem_obj_properties_t MemObjData{};
2334- MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2335- MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2336- Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2337- &MemObjData, MemArg);
2333+ ur_exp_kernel_arg_value_t Value = {};
2334+ Value.memObjTuple = {MemArg, AccessModeToUr (Req->MAccessMode )};
2335+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2336+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2337+ static_cast <uint32_t >(NextTrueIndex), sizeof (MemArg),
2338+ Value});
23382339 break ;
23392340 }
23402341 case kernel_param_kind_t ::kind_std_layout: {
2342+ ur_exp_kernel_arg_type_t Type;
23412343 if (Arg.MPtr ) {
2342- Adapter.call <UrApiKind::urKernelSetArgValue>(
2343- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2344+ Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
23442345 } else {
2345- Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2346- Arg.MSize , nullptr );
2346+ Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
23472347 }
2348+ ur_exp_kernel_arg_value_t Value = {};
2349+ Value.value = {Arg.MPtr };
2350+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2351+ Type, static_cast <uint32_t >(NextTrueIndex),
2352+ static_cast <size_t >(Arg.MSize ), Value});
23482353
23492354 break ;
23502355 }
23512356 case kernel_param_kind_t ::kind_sampler: {
23522357 sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2353- ur_sampler_handle_t Sampler =
2354- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2355- ->getOrCreateSampler (ContextImpl);
2356- Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2357- nullptr , Sampler);
2358+ ur_exp_kernel_arg_value_t Value = {};
2359+ Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2360+ ->getOrCreateSampler (ContextImpl);
2361+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2362+ UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2363+ static_cast <uint32_t >(NextTrueIndex),
2364+ sizeof (ur_sampler_handle_t ), Value});
23582365 break ;
23592366 }
23602367 case kernel_param_kind_t ::kind_pointer: {
2361- // We need to de-rerence this to get the actual USM allocation - that's the
2368+ ur_exp_kernel_arg_value_t Value = {};
2369+ // We need to de-rerence to get the actual USM allocation - that's the
23622370 // pointer UR is expecting.
2363- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2364- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2365- nullptr , Ptr);
2371+ Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2372+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2373+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2374+ static_cast <uint32_t >(NextTrueIndex), sizeof (Arg.MPtr ),
2375+ Value});
23662376 break ;
23672377 }
23682378 case kernel_param_kind_t ::kind_specialization_constants_buffer: {
23692379 assert (DeviceImageImpl != nullptr );
23702380 ur_mem_handle_t SpecConstsBuffer =
23712381 DeviceImageImpl->get_spec_const_buffer_ref ();
2372-
2373- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2374- MemObjProps.pNext = nullptr ;
2375- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2376- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2377- Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2378- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2382+ ur_exp_kernel_arg_value_t Value = {};
2383+ Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2384+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2385+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2386+ static_cast <uint32_t >(NextTrueIndex),
2387+ sizeof (SpecConstsBuffer), Value});
23792388 break ;
23802389 }
23812390 case kernel_param_kind_t ::kind_invalid:
@@ -2404,22 +2413,32 @@ static ur_result_t SetKernelParamsAndLaunch(
24042413 DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
24052414 }
24062415
2416+ std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2417+ UrArgs.reserve (Args.size ());
2418+
24072419 if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures ) {
2408- auto setFunc = [&Adapter, Kernel ,
2420+ auto setFunc = [&UrArgs ,
24092421 KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24102422 size_t NextTrueIndex) {
24112423 const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
24122424 switch (ParamDesc.kind ) {
24132425 case kernel_param_kind_t ::kind_std_layout: {
24142426 int Size = ParamDesc.info ;
2415- Adapter.call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2416- Size, nullptr , ArgPtr);
2427+ ur_exp_kernel_arg_value_t Value = {};
2428+ Value.value = ArgPtr;
2429+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2430+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2431+ static_cast <uint32_t >(NextTrueIndex),
2432+ static_cast <size_t >(Size), Value});
24172433 break ;
24182434 }
24192435 case kernel_param_kind_t ::kind_pointer: {
2420- const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2421- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2422- nullptr , Ptr);
2436+ ur_exp_kernel_arg_value_t Value = {};
2437+ Value.pointer = *static_cast <const void *const *>(ArgPtr);
2438+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2439+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2440+ static_cast <uint32_t >(NextTrueIndex),
2441+ sizeof (Value.pointer ), Value});
24232442 break ;
24242443 }
24252444 default :
@@ -2429,23 +2448,28 @@ static ur_result_t SetKernelParamsAndLaunch(
24292448 applyFuncOnFilteredArgs (EliminatedArgMask, DeviceKernelInfo.NumParams ,
24302449 DeviceKernelInfo.ParamDescGetter , setFunc);
24312450 } else {
2432- auto setFunc = [&Adapter, Kernel, &DeviceImageImpl , &getMemAllocationFunc ,
2433- &Queue ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2434- SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2435- Queue.getContextImpl (), Arg, NextTrueIndex);
2451+ auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc , &Queue ,
2452+ &UrArgs ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2453+ GetUrArgsBasedOnType ( DeviceImageImpl, getMemAllocationFunc,
2454+ Queue.getContextImpl (), Arg, NextTrueIndex, UrArgs );
24362455 };
24372456 applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
24382457 }
24392458
2440- const std::optional<int > &ImplicitLocalArg =
2441- DeviceKernelInfo.getImplicitLocalArgPos ();
2459+ std::optional<int > ImplicitLocalArg =
2460+ ProgramManager::getInstance ().kernelImplicitLocalArgPos (
2461+ DeviceKernelInfo.Name );
24422462 // Set the implicit local memory buffer to support
24432463 // get_work_group_scratch_memory. This is for backend not supporting
24442464 // CUDA-style local memory setting. Note that we may have -1 as a position,
24452465 // this indicates the buffer is actually unused and was elided.
24462466 if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2447- Adapter.call <UrApiKind::urKernelSetArgLocal>(
2448- Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
2467+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2468+ nullptr ,
2469+ UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2470+ static_cast <uint32_t >(ImplicitLocalArg.value ()),
2471+ WorkGroupMemorySize,
2472+ {nullptr }});
24492473 }
24502474
24512475 adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2468,16 +2492,14 @@ static ur_result_t SetKernelParamsAndLaunch(
24682492 /* pPropSizeRet = */ nullptr );
24692493
24702494 const bool EnforcedLocalSize =
2471- (RequiredWGSize[0 ] != 0 &&
2472- (NDRDesc.Dims < 2 || RequiredWGSize[1 ] != 0 ) &&
2473- (NDRDesc.Dims < 3 || RequiredWGSize[2 ] != 0 ));
2495+ (RequiredWGSize[0 ] != 0 || RequiredWGSize[1 ] != 0 ||
2496+ RequiredWGSize[2 ] != 0 );
24742497 if (EnforcedLocalSize)
24752498 LocalSize = RequiredWGSize;
24762499 }
2477-
2478- const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 &&
2479- (NDRDesc.Dims < 2 || NDRDesc.GlobalOffset [1 ] != 0 ) &&
2480- (NDRDesc.Dims < 3 || NDRDesc.GlobalOffset [2 ] != 0 );
2500+ const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 ||
2501+ NDRDesc.GlobalOffset [1 ] != 0 ||
2502+ NDRDesc.GlobalOffset [2 ] != 0 ;
24812503
24822504 std::vector<ur_kernel_launch_property_t > property_list;
24832505
@@ -2505,20 +2527,104 @@ static ur_result_t SetKernelParamsAndLaunch(
25052527 {{WorkGroupMemorySize}}});
25062528 }
25072529 ur_event_handle_t UREvent = nullptr ;
2508- ur_result_t Error = Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2509- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2510- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2511- LocalSize, property_list.size (),
2512- property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2513- RawEvents.empty () ? nullptr : &RawEvents[0 ],
2514- OutEventImpl ? &UREvent : nullptr );
2530+ ur_result_t Error =
2531+ Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2532+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2533+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2534+ &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2535+ property_list.size (),
2536+ property_list.empty () ? nullptr : property_list.data (),
2537+ RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2538+ OutEventImpl ? &UREvent : nullptr );
25152539 if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25162540 OutEventImpl->setHandle (UREvent);
25172541 }
25182542
25192543 return Error;
25202544}
25212545
2546+ // Sets arguments for a given kernel and device based on the argument type.
2547+ // This is a legacy path which the graphs extension still uses.
2548+ static void SetArgBasedOnType (
2549+ adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2550+ device_image_impl *DeviceImageImpl,
2551+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2552+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2553+ switch (Arg.MType ) {
2554+ case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2555+ break ;
2556+ case kernel_param_kind_t ::kind_work_group_memory:
2557+ break ;
2558+ case kernel_param_kind_t ::kind_stream:
2559+ break ;
2560+ case kernel_param_kind_t ::kind_dynamic_accessor:
2561+ case kernel_param_kind_t ::kind_accessor: {
2562+ Requirement *Req = (Requirement *)(Arg.MPtr );
2563+
2564+ // getMemAllocationFunc is nullptr when there are no requirements. However,
2565+ // we may pass default constructed accessors to a command, which don't add
2566+ // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2567+ // valid case, so we need to properly handle it.
2568+ ur_mem_handle_t MemArg =
2569+ getMemAllocationFunc
2570+ ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2571+ : nullptr ;
2572+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2573+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2574+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2575+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2576+ &MemObjData, MemArg);
2577+ break ;
2578+ }
2579+ case kernel_param_kind_t ::kind_std_layout: {
2580+ if (Arg.MPtr ) {
2581+ Adapter.call <UrApiKind::urKernelSetArgValue>(
2582+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2583+ } else {
2584+ Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2585+ Arg.MSize , nullptr );
2586+ }
2587+
2588+ break ;
2589+ }
2590+ case kernel_param_kind_t ::kind_sampler: {
2591+ sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2592+ ur_sampler_handle_t Sampler =
2593+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2594+ ->getOrCreateSampler (ContextImpl);
2595+ Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2596+ nullptr , Sampler);
2597+ break ;
2598+ }
2599+ case kernel_param_kind_t ::kind_pointer: {
2600+ // We need to de-rerence this to get the actual USM allocation - that's the
2601+ // pointer UR is expecting.
2602+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2603+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2604+ nullptr , Ptr);
2605+ break ;
2606+ }
2607+ case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2608+ assert (DeviceImageImpl != nullptr );
2609+ ur_mem_handle_t SpecConstsBuffer =
2610+ DeviceImageImpl->get_spec_const_buffer_ref ();
2611+
2612+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2613+ MemObjProps.pNext = nullptr ;
2614+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2615+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2616+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2617+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2618+ break ;
2619+ }
2620+ case kernel_param_kind_t ::kind_invalid:
2621+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2622+ " Invalid kernel param kind " +
2623+ codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2624+ break ;
2625+ }
2626+ }
2627+
25222628static std::tuple<ur_kernel_handle_t , device_image_impl *,
25232629 const KernelArgMask *>
25242630getCGKernelInfo (const CGExecKernel &CommandGroup, context_impl &ContextImpl,
0 commit comments