From 09d506456844d5cd789ed85abd0821eb86433fb7 Mon Sep 17 00:00:00 2001 From: Artur Gainullin Date: Mon, 22 Sep 2025 16:09:11 -0700 Subject: [PATCH] [UR][L0] Set pointer kernel arguments only for queue's associated device Ensure that pointer kernel arguments are set only for the device associated with the queue being used for kernel launch. Previously, arguments were set for all devices in the kernel's device map, which was unnecessary and potentially incorrect when launching on a specific device. --- sycl/test-e2e/MultiDevice/set_arg_pointer.cpp | 63 ++++++++++++++++++ .../adapters/level_zero/command_buffer.cpp | 14 ++-- .../level_zero/helpers/kernel_helpers.cpp | 22 +++++++ .../level_zero/helpers/kernel_helpers.hpp | 12 ++++ .../source/adapters/level_zero/kernel.cpp | 65 ++++++++++--------- .../source/adapters/level_zero/kernel.hpp | 7 +- 6 files changed, 145 insertions(+), 38 deletions(-) create mode 100644 sycl/test-e2e/MultiDevice/set_arg_pointer.cpp diff --git a/sycl/test-e2e/MultiDevice/set_arg_pointer.cpp b/sycl/test-e2e/MultiDevice/set_arg_pointer.cpp new file mode 100644 index 0000000000000..39b7ceb79488e --- /dev/null +++ b/sycl/test-e2e/MultiDevice/set_arg_pointer.cpp @@ -0,0 +1,63 @@ +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out + +// UNSUPPORTED: level_zero_v2_adapter +// UNSUPPORTED-TRACKER: CMPLRLLVM-67039 + +// Test that usm device pointer can be used in a kernel compiled for a context +// with multiple devices. 
+ +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +class AddIdxKernel; + +int main() { + sycl::platform plt; + std::vector devices = plt.get_devices(); + if (devices.size() < 2) { + std::cout << "Need at least 2 GPU devices for this test.\n"; + return 0; + } + + std::vector ctx_devices{devices[0], devices[1]}; + sycl::context ctx(ctx_devices); + + constexpr size_t N = 16; + std::vector> results(ctx_devices.size(), + std::vector(N, 0)); + + // Create a kernel bundle compiled for both devices in the context + auto kb = sycl::get_kernel_bundle(ctx); + + // For each device, create a queue and run a kernel using device USM + for (size_t i = 0; i < ctx_devices.size(); ++i) { + sycl::queue q(ctx, ctx_devices[i]); + int *data = sycl::malloc_device(N, q); + q.fill(data, 1, N).wait(); + q.submit([&](sycl::handler &h) { + h.use_kernel_bundle(kb); + h.parallel_for( + sycl::range<1>(N), [=](sycl::id<1> idx) { data[idx] += idx[0]; }); + }).wait(); + q.memcpy(results[i].data(), data, N * sizeof(int)).wait(); + sycl::free(data, q); + } + + for (size_t i = 0; i < ctx_devices.size(); ++i) { + std::cout << "Device " << i << " results: "; + for (size_t j = 0; j < N; ++j) { + if (results[i][j] != 1 + static_cast(j)) { + return -1; + } + std::cout << results[i][j] << " "; + } + } + return 0; +} diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/command_buffer.cpp index 687c905417d8b..25d45f7232636 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.cpp @@ -1004,12 +1004,16 @@ ur_result_t setKernelPendingArguments( ze_kernel_handle_t ZeKernel) { // If there are any pending arguments set them now. for (auto &Arg : PendingArguments) { - // The ArgValue may be a NULL pointer in which case a NULL value is used for - // the kernel argument declared as a pointer to global or constant memory. 
char **ZeHandlePtr = nullptr; - if (Arg.Value) { - UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device, - nullptr, 0u)); + if (auto MemObjPtr = std::get_if(&Arg.Value)) { + ur_mem_handle_t MemObj = *MemObjPtr; + if (MemObj) { + UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device, + nullptr, 0u)); + } + } else { + auto Ptr = const_cast(&std::get(Arg.Value)); + ZeHandlePtr = reinterpret_cast(Ptr); } ZE2UR_CALL(zeKernelSetArgumentValue, (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); diff --git a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp index 97aac29a84fbe..a8c75e41e44da 100644 --- a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp +++ b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp @@ -156,3 +156,25 @@ ur_result_t calculateKernelWorkDimensions( return UR_RESULT_SUCCESS; } + +ur_result_t setArgValueOnZeKernel(ze_kernel_handle_t hZeKernel, + uint32_t argIndex, size_t argSize, + const void *pArgValue) { + // OpenCL: "the arg_value pointer can be NULL or point to a NULL value + // in which case a NULL value will be used as the value for the argument + // declared as a pointer to global or constant memory in the kernel" + // + // We don't know the type of the argument but it seems that the only time + // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument + // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL. 
+ if (argSize == sizeof(void *) && pArgValue && + *(void **)(const_cast<void *>(pArgValue)) == nullptr) { + pArgValue = nullptr; + } + + ze_result_t ZeResult = ZE_CALL_NOCHECK( + zeKernelSetArgumentValue, (hZeKernel, argIndex, argSize, pArgValue)); + if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) + return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE; + return ze2urResult(ZeResult); +} diff --git a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp index 5dcf0c9123045..cc090bed8d7d9 100644 --- a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp +++ b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.hpp @@ -71,3 +71,15 @@ inline void postSubmit(ze_kernel_handle_t hZeKernel, zeKernelSetGlobalOffsetExp(hZeKernel, 0, 0, 0); } } + +/** + * Helper to set kernel argument for ze_kernel_handle_t. + * @param[in] hZeKernel The handle to the Level-Zero kernel. + * @param[in] argIndex The index of the argument to set. + * @param[in] argSize The size of the argument to set. + * @param[in] pArgValue The pointer to the argument value. + * @return UR_RESULT_SUCCESS or an error code on failure + */ +ur_result_t setArgValueOnZeKernel(ze_kernel_handle_t hZeKernel, + uint32_t argIndex, size_t argSize, + const void *pArgValue); diff --git a/unified-runtime/source/adapters/level_zero/kernel.cpp b/unified-runtime/source/adapters/level_zero/kernel.cpp index 45b7b087cece5..bcac9cb04c320 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/kernel.cpp @@ -125,16 +125,22 @@ ur_result_t urEnqueueKernelLaunch( // If there are any pending arguments set them now. for (auto &Arg : Kernel->PendingArguments) { - // The ArgValue may be a NULL pointer in which case a NULL value is used for - // the kernel argument declared as a pointer to global or constant memory. 
+ // The Arg.Value can be either a ur_mem_handle_t or a raw pointer + // (const void*). Resolve per-device: for mem handles obtain the device + // specific handle, otherwise pass the raw pointer value. char **ZeHandlePtr = nullptr; - if (Arg.Value) { - UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, - Queue->Device, EventWaitList, - NumEventsInWaitList)); + if (auto MemObjPtr = std::get_if(&Arg.Value)) { + ur_mem_handle_t MemObj = *MemObjPtr; + if (MemObj) { + UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, + Queue->Device, EventWaitList, + NumEventsInWaitList)); + } + } else { + auto Ptr = const_cast(&std::get(Arg.Value)); + ZeHandlePtr = reinterpret_cast(Ptr); } - ZE2UR_CALL(zeKernelSetArgumentValue, - (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); + UR_CALL(setArgValueOnZeKernel(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); } Kernel->PendingArguments.clear(); @@ -422,41 +428,21 @@ ur_result_t urKernelSetArgValue( UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); - // OpenCL: "the arg_value pointer can be NULL or point to a NULL value - // in which case a NULL value will be used as the value for the argument - // declared as a pointer to global or constant memory in the kernel" - // - // We don't know the type of the argument but it seems that the only time - // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument - // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL. 
- if (ArgSize == sizeof(void *) && PArgValue && - *(void **)(const_cast<void *>(PArgValue)) == nullptr) { - PArgValue = nullptr; - } - if (ArgIndex > Kernel->ZeKernelProperties->numKernelArgs - 1) { return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX; } std::scoped_lock Guard(Kernel->Mutex); - ze_result_t ZeResult = ZE_RESULT_SUCCESS; if (Kernel->ZeKernelMap.empty()) { auto ZeKernel = Kernel->ZeKernel; - ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, - (ZeKernel, ArgIndex, ArgSize, PArgValue)); + UR_CALL(setArgValueOnZeKernel(ZeKernel, ArgIndex, ArgSize, PArgValue)) } else { for (auto It : Kernel->ZeKernelMap) { auto ZeKernel = It.second; - ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, - (ZeKernel, ArgIndex, ArgSize, PArgValue)); + UR_CALL(setArgValueOnZeKernel(ZeKernel, ArgIndex, ArgSize, PArgValue)) } } - - if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) { - return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE; - } - - return ze2urResult(ZeResult); + return UR_RESULT_SUCCESS; } ur_result_t urKernelSetArgLocal( @@ -732,6 +718,23 @@ ur_result_t urKernelSetArgPointer( /// [in][optional] SVM pointer to memory location holding the argument /// value. If null then argument value is considered null. const void *ArgValue) { + UR_ASSERT(Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); + { + std::scoped_lock Guard(Kernel->Mutex); + // In a multi-device context, instead of setting pointer arguments immediately + // across all device kernels, store them as pending so they can be resolved + // per-device at enqueue time. This ensures the correct handle is used for + // the device of the queue. 
+ if (Kernel->Program->Context->getDevices().size() > 1) { + if (ArgIndex > Kernel->ZeKernelProperties->numKernelArgs - 1) { + return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX; + } + Kernel->PendingArguments.push_back({ArgIndex, sizeof(const void *), + ArgValue, ur_mem_handle_t_::unknown}); + + return UR_RESULT_SUCCESS; + } + } // KernelSetArgValue is expecting a pointer to the argument UR_CALL(ur::level_zero::urKernelSetArgValue( diff --git a/unified-runtime/source/adapters/level_zero/kernel.hpp b/unified-runtime/source/adapters/level_zero/kernel.hpp index 131dba270c05d..38e2e43e366b6 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/kernel.hpp @@ -10,6 +10,7 @@ #pragma once #include +#include #include "common.hpp" #include "common/ur_ref_count.hpp" @@ -97,8 +98,10 @@ struct ur_kernel_handle_t_ : ur_object { struct ArgumentInfo { uint32_t Index; size_t Size; - // const ur_mem_handle_t_ *Value; - ur_mem_handle_t_ *Value; + // Value may be either a memory object or a raw pointer value (for pointer + // arguments). Resolve at enqueue time per-device to ensure correct handle + // is used for that device. + std::variant Value; ur_mem_handle_t_::access_mode_t AccessMode{ur_mem_handle_t_::unknown}; }; // Arguments that still need to be set (with zeKernelSetArgumentValue)