-
Notifications
You must be signed in to change notification settings - Fork 794
[UR][SYCL] Add support for zeCommandListAppendLaunchKernelWithArguments()
#20316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2303,14 +2303,22 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) { | |
| } | ||
| } | ||
|
|
||
| // Sets arguments for a given kernel and device based on the argument type. | ||
| // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs | ||
| // extension. | ||
| static void SetArgBasedOnType( | ||
| adapter_impl &Adapter, ur_kernel_handle_t Kernel, | ||
| // Gets UR argument struct for a given kernel and device based on the argument | ||
| // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in | ||
| // the graphs extension (LaunchWithArgs for graphs is planned future work). | ||
| static void GetUrArgsBasedOnType( | ||
| device_image_impl *DeviceImageImpl, | ||
| const std::function<void *(Requirement *Req)> &getMemAllocationFunc, | ||
| context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) { | ||
| context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex, | ||
| std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) { | ||
| // UrArg.size == 0 indicates uninitialized structure | ||
| ur_exp_kernel_arg_properties_t UrArg = { | ||
| UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, | ||
| nullptr, | ||
| UR_EXP_KERNEL_ARG_TYPE_VALUE, | ||
| static_cast<uint32_t>(NextTrueIndex), | ||
| 0, | ||
| {}}; | ||
| switch (Arg.MType) { | ||
| case kernel_param_kind_t::kind_dynamic_work_group_memory: | ||
| break; | ||
|
|
@@ -2330,52 +2338,56 @@ static void SetArgBasedOnType( | |
| getMemAllocationFunc | ||
| ? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req)) | ||
| : nullptr; | ||
| ur_kernel_arg_mem_obj_properties_t MemObjData{}; | ||
| MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES; | ||
| MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode); | ||
| Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex, | ||
| &MemObjData, MemArg); | ||
| ur_exp_kernel_arg_value_t Value = {}; | ||
| Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)}; | ||
| UrArg.type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ; | ||
| UrArg.size = sizeof(MemArg); | ||
| UrArg.value = Value; | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_std_layout: { | ||
| ur_exp_kernel_arg_type_t Type; | ||
| if (Arg.MPtr) { | ||
| Adapter.call<UrApiKind::urKernelSetArgValue>( | ||
| Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr); | ||
| Type = UR_EXP_KERNEL_ARG_TYPE_VALUE; | ||
| } else { | ||
| Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex, | ||
| Arg.MSize, nullptr); | ||
| Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL; | ||
| } | ||
|
|
||
| ur_exp_kernel_arg_value_t Value = {}; | ||
| Value.value = {Arg.MPtr}; | ||
| UrArg.type = Type; | ||
| UrArg.size = static_cast<size_t>(Arg.MSize); | ||
| UrArg.value = Value; | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_sampler: { | ||
| sampler *SamplerPtr = (sampler *)Arg.MPtr; | ||
| ur_sampler_handle_t Sampler = | ||
| (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr) | ||
| ->getOrCreateSampler(ContextImpl); | ||
| Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex, | ||
| nullptr, Sampler); | ||
| ur_exp_kernel_arg_value_t Value = {}; | ||
| Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr) | ||
| ->getOrCreateSampler(ContextImpl); | ||
| UrArg.type = UR_EXP_KERNEL_ARG_TYPE_SAMPLER; | ||
| UrArg.size = sizeof(ur_sampler_handle_t); | ||
| UrArg.value = Value; | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_pointer: { | ||
| // We need to de-rerence this to get the actual USM allocation - that's the | ||
| ur_exp_kernel_arg_value_t Value = {}; | ||
| // We need to dereference to get the actual USM allocation - that's the | ||
| // pointer UR is expecting. | ||
| const void *Ptr = *static_cast<const void *const *>(Arg.MPtr); | ||
| Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex, | ||
| nullptr, Ptr); | ||
| Value.pointer = *static_cast<void *const *>(Arg.MPtr); | ||
| UrArg.type = UR_EXP_KERNEL_ARG_TYPE_POINTER; | ||
| UrArg.size = sizeof(Arg.MPtr); | ||
| UrArg.value = Value; | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_specialization_constants_buffer: { | ||
| assert(DeviceImageImpl != nullptr); | ||
| ur_mem_handle_t SpecConstsBuffer = | ||
| DeviceImageImpl->get_spec_const_buffer_ref(); | ||
|
|
||
| ur_kernel_arg_mem_obj_properties_t MemObjProps{}; | ||
| MemObjProps.pNext = nullptr; | ||
| MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES; | ||
| MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY; | ||
| Adapter.call<UrApiKind::urKernelSetArgMemObj>( | ||
| Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer); | ||
| ur_exp_kernel_arg_value_t Value = {}; | ||
| Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY}; | ||
| UrArg.type = UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ; | ||
| UrArg.size = sizeof(SpecConstsBuffer); | ||
| UrArg.value = Value; | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_invalid: | ||
|
|
@@ -2384,6 +2396,10 @@ static void SetArgBasedOnType( | |
| codeToString(UR_RESULT_ERROR_INVALID_VALUE)); | ||
| break; | ||
| } | ||
|
|
||
| if (UrArg.size) { | ||
| UrArgs.push_back(UrArg); | ||
| } | ||
| } | ||
|
|
||
| static ur_result_t SetKernelParamsAndLaunch( | ||
|
|
@@ -2404,22 +2420,33 @@ static ur_result_t SetKernelParamsAndLaunch( | |
| DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref() : Empty); | ||
| } | ||
|
|
||
| // just a performance optimization - avoid heap allocations | ||
| static thread_local std::vector<ur_exp_kernel_arg_properties_t> UrArgs; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This line introduces another perf regression. Since we are using TLS here, there are a lot of calls to the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Originally there was: 24f54ea#diff-654cc430ad3aa564d94299199b9b53ce9cb48b2301edfa8890d6bd509dd23e55R2432-R2433 We have tried to fix it: #20316 (comment), but we have no other ideas either... |
||
| UrArgs.clear(); | ||
| UrArgs.reserve(Args.size()); | ||
|
|
||
| if (KernelFuncPtr && !DeviceKernelInfo.HasSpecialCaptures) { | ||
| auto setFunc = [&Adapter, Kernel, | ||
| KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc, | ||
| auto setFunc = [KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc, | ||
| size_t NextTrueIndex) { | ||
| const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset; | ||
| switch (ParamDesc.kind) { | ||
| case kernel_param_kind_t::kind_std_layout: { | ||
| int Size = ParamDesc.info; | ||
| Adapter.call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex, | ||
| Size, nullptr, ArgPtr); | ||
| ur_exp_kernel_arg_value_t Value = {}; | ||
| Value.value = ArgPtr; | ||
| UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr, | ||
| UR_EXP_KERNEL_ARG_TYPE_VALUE, | ||
| static_cast<uint32_t>(NextTrueIndex), | ||
| static_cast<size_t>(Size), Value}); | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_pointer: { | ||
| const void *Ptr = *static_cast<const void *const *>(ArgPtr); | ||
| Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex, | ||
| nullptr, Ptr); | ||
| ur_exp_kernel_arg_value_t Value = {}; | ||
| Value.pointer = *static_cast<const void *const *>(ArgPtr); | ||
| UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr, | ||
| UR_EXP_KERNEL_ARG_TYPE_POINTER, | ||
| static_cast<uint32_t>(NextTrueIndex), | ||
| sizeof(Value.pointer), Value}); | ||
| break; | ||
| } | ||
| default: | ||
|
|
@@ -2429,23 +2456,28 @@ static ur_result_t SetKernelParamsAndLaunch( | |
| applyFuncOnFilteredArgs(EliminatedArgMask, DeviceKernelInfo.NumParams, | ||
| DeviceKernelInfo.ParamDescGetter, setFunc); | ||
| } else { | ||
| auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc, | ||
| auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc, | ||
| &Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) { | ||
| SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc, | ||
| Queue.getContextImpl(), Arg, NextTrueIndex); | ||
| GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc, | ||
| Queue.getContextImpl(), Arg, NextTrueIndex, UrArgs); | ||
| }; | ||
| applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc); | ||
| } | ||
|
|
||
| const std::optional<int> &ImplicitLocalArg = | ||
| DeviceKernelInfo.getImplicitLocalArgPos(); | ||
| std::optional<int> ImplicitLocalArg = | ||
| ProgramManager::getInstance().kernelImplicitLocalArgPos( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is a part of the 49e4eb2 commit that is a revert of revert of #18764. I made it by accident during this revert, because it was present on the So it is a mistake made during the revert of a very huge commit.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| DeviceKernelInfo.Name); | ||
| // Set the implicit local memory buffer to support | ||
| // get_work_group_scratch_memory. This is for backend not supporting | ||
| // CUDA-style local memory setting. Note that we may have -1 as a position, | ||
| // this indicates the buffer is actually unused and was elided. | ||
| if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) { | ||
| Adapter.call<UrApiKind::urKernelSetArgLocal>( | ||
| Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr); | ||
| UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, | ||
| nullptr, | ||
| UR_EXP_KERNEL_ARG_TYPE_LOCAL, | ||
| static_cast<uint32_t>(ImplicitLocalArg.value()), | ||
| WorkGroupMemorySize, | ||
| {nullptr}}); | ||
| } | ||
|
|
||
| adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl()); | ||
|
|
@@ -2468,16 +2500,14 @@ static ur_result_t SetKernelParamsAndLaunch( | |
| /* pPropSizeRet = */ nullptr); | ||
|
|
||
| const bool EnforcedLocalSize = | ||
| (RequiredWGSize[0] != 0 && | ||
| (NDRDesc.Dims < 2 || RequiredWGSize[1] != 0) && | ||
| (NDRDesc.Dims < 3 || RequiredWGSize[2] != 0)); | ||
| (RequiredWGSize[0] != 0 || RequiredWGSize[1] != 0 || | ||
| RequiredWGSize[2] != 0); | ||
| if (EnforcedLocalSize) | ||
| LocalSize = RequiredWGSize; | ||
| } | ||
|
|
||
| const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 && | ||
| (NDRDesc.Dims < 2 || NDRDesc.GlobalOffset[1] != 0) && | ||
| (NDRDesc.Dims < 3 || NDRDesc.GlobalOffset[2] != 0); | ||
| const bool HasOffset = NDRDesc.GlobalOffset[0] != 0 || | ||
| NDRDesc.GlobalOffset[1] != 0 || | ||
| NDRDesc.GlobalOffset[2] != 0; | ||
|
|
||
| std::vector<ur_kernel_launch_property_t> property_list; | ||
|
|
||
|
|
@@ -2505,20 +2535,104 @@ static ur_result_t SetKernelParamsAndLaunch( | |
| {{WorkGroupMemorySize}}}); | ||
| } | ||
| ur_event_handle_t UREvent = nullptr; | ||
| ur_result_t Error = Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunch>( | ||
| Queue.getHandleRef(), Kernel, NDRDesc.Dims, | ||
| HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0], | ||
| LocalSize, property_list.size(), | ||
| property_list.empty() ? nullptr : property_list.data(), RawEvents.size(), | ||
| RawEvents.empty() ? nullptr : &RawEvents[0], | ||
| OutEventImpl ? &UREvent : nullptr); | ||
| ur_result_t Error = | ||
| Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>( | ||
| Queue.getHandleRef(), Kernel, NDRDesc.Dims, | ||
| HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, | ||
| &NDRDesc.GlobalSize[0], LocalSize, UrArgs.size(), UrArgs.data(), | ||
| property_list.size(), | ||
| property_list.empty() ? nullptr : property_list.data(), | ||
| RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0], | ||
| OutEventImpl ? &UREvent : nullptr); | ||
| if (Error == UR_RESULT_SUCCESS && OutEventImpl) { | ||
| OutEventImpl->setHandle(UREvent); | ||
| } | ||
|
|
||
| return Error; | ||
| } | ||
|
|
||
| // Sets arguments for a given kernel and device based on the argument type. | ||
| // This is a legacy path which the graphs extension still uses. | ||
| static void SetArgBasedOnType( | ||
| adapter_impl &Adapter, ur_kernel_handle_t Kernel, | ||
| device_image_impl *DeviceImageImpl, | ||
| const std::function<void *(Requirement *Req)> &getMemAllocationFunc, | ||
| context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) { | ||
| switch (Arg.MType) { | ||
| case kernel_param_kind_t::kind_dynamic_work_group_memory: | ||
| break; | ||
| case kernel_param_kind_t::kind_work_group_memory: | ||
| break; | ||
| case kernel_param_kind_t::kind_stream: | ||
| break; | ||
| case kernel_param_kind_t::kind_dynamic_accessor: | ||
| case kernel_param_kind_t::kind_accessor: { | ||
| Requirement *Req = (Requirement *)(Arg.MPtr); | ||
|
|
||
| // getMemAllocationFunc is nullptr when there are no requirements. However, | ||
| // we may pass default constructed accessors to a command, which don't add | ||
| // requirements. In such case, getMemAllocationFunc is nullptr, but it's a | ||
| // valid case, so we need to properly handle it. | ||
| ur_mem_handle_t MemArg = | ||
| getMemAllocationFunc | ||
| ? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req)) | ||
| : nullptr; | ||
| ur_kernel_arg_mem_obj_properties_t MemObjData{}; | ||
| MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES; | ||
| MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode); | ||
| Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex, | ||
| &MemObjData, MemArg); | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_std_layout: { | ||
| if (Arg.MPtr) { | ||
| Adapter.call<UrApiKind::urKernelSetArgValue>( | ||
| Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr); | ||
| } else { | ||
| Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex, | ||
| Arg.MSize, nullptr); | ||
| } | ||
|
|
||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_sampler: { | ||
| sampler *SamplerPtr = (sampler *)Arg.MPtr; | ||
| ur_sampler_handle_t Sampler = | ||
| (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr) | ||
| ->getOrCreateSampler(ContextImpl); | ||
| Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex, | ||
| nullptr, Sampler); | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_pointer: { | ||
| // We need to dereference this to get the actual USM allocation - that's the | ||
| // pointer UR is expecting. | ||
| const void *Ptr = *static_cast<const void *const *>(Arg.MPtr); | ||
| Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex, | ||
| nullptr, Ptr); | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_specialization_constants_buffer: { | ||
| assert(DeviceImageImpl != nullptr); | ||
| ur_mem_handle_t SpecConstsBuffer = | ||
| DeviceImageImpl->get_spec_const_buffer_ref(); | ||
|
|
||
| ur_kernel_arg_mem_obj_properties_t MemObjProps{}; | ||
| MemObjProps.pNext = nullptr; | ||
| MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES; | ||
| MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY; | ||
| Adapter.call<UrApiKind::urKernelSetArgMemObj>( | ||
| Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer); | ||
| break; | ||
| } | ||
| case kernel_param_kind_t::kind_invalid: | ||
| throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), | ||
| "Invalid kernel param kind " + | ||
| codeToString(UR_RESULT_ERROR_INVALID_VALUE)); | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| static std::tuple<ur_kernel_handle_t, device_image_impl *, | ||
| const KernelArgMask *> | ||
| getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.