diff --git a/source/adapters/level_zero/memory.cpp b/source/adapters/level_zero/memory.cpp index 82ecd7043b..4757a0563d 100644 --- a/source/adapters/level_zero/memory.cpp +++ b/source/adapters/level_zero/memory.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "context.hpp" #include "event.hpp" @@ -183,9 +184,6 @@ static ur_result_t enqueueMemFillHelper(ur_command_t CommandType, uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList, ur_event_handle_t *OutEvent) { - // Pattern size must be a power of two. - UR_ASSERT((PatternSize > 0) && ((PatternSize & (PatternSize - 1)) == 0), - UR_RESULT_ERROR_INVALID_VALUE); auto &Device = Queue->Device; // Make sure that pattern size matches the capability of the copy queues. @@ -237,18 +235,42 @@ static ur_result_t enqueueMemFillHelper(ur_command_t CommandType, const auto &ZeCommandList = CommandList->first; const auto &WaitList = (*Event)->WaitList; - ZE2UR_CALL(zeCommandListAppendMemoryFill, - (ZeCommandList, Ptr, Pattern, PatternSize, Size, ZeEvent, - WaitList.Length, WaitList.ZeEventList)); + // PatternSize must be a power of two for zeCommandListAppendMemoryFill. + // When it's not, the fill is emulated with zeCommandListAppendMemoryCopy. + if (isPowerOf2(PatternSize)) { + ZE2UR_CALL(zeCommandListAppendMemoryFill, + (ZeCommandList, Ptr, Pattern, PatternSize, Size, ZeEvent, + WaitList.Length, WaitList.ZeEventList)); - logger::debug("calling zeCommandListAppendMemoryFill() with" - " ZeEvent {}", - ur_cast(ZeEvent)); - printZeEventList(WaitList); + logger::debug("calling zeCommandListAppendMemoryFill() with" + " ZeEvent {}", + ur_cast(ZeEvent)); + printZeEventList(WaitList); - // Execute command list asynchronously, as the event will be used - // to track down its completion. - UR_CALL(Queue->executeCommandList(CommandList, false, OkToBatch)); + // Execute command list asynchronously, as the event will be used + // to track down its completion. + UR_CALL(Queue->executeCommandList(CommandList, false, OkToBatch)); + } else { + // Copy pattern into every entry in memory array pointed by Ptr. + uint32_t NumOfCopySteps = Size / PatternSize; + const void *Src = Pattern; + + for (uint32_t step = 0; step < NumOfCopySteps; ++step) { + void *Dst = reinterpret_cast(reinterpret_cast(Ptr) + + step * PatternSize); + ZE2UR_CALL(zeCommandListAppendMemoryCopy, + (ZeCommandList, Dst, Src, PatternSize, ZeEvent, + WaitList.Length, WaitList.ZeEventList)); + } + + logger::debug("calling zeCommandListAppendMemoryCopy() with" + " ZeEvent {}", + ur_cast(ZeEvent)); + printZeEventList(WaitList); + + // Execute command list synchronously. + UR_CALL(Queue->executeCommandList(CommandList, true, OkToBatch)); + } return UR_RESULT_SUCCESS; } diff --git a/source/adapters/opencl/usm.cpp b/source/adapters/opencl/usm.cpp index 0d64f23d13..3f4382fc0d 100644 --- a/source/adapters/opencl/usm.cpp +++ b/source/adapters/opencl/usm.cpp @@ -8,6 +8,8 @@ // //===----------------------------------------------------------------------===// +#include + #include "common.hpp" inline cl_mem_alloc_flags_intel @@ -239,7 +241,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( return mapCLErrorToUR(CLErr); } - if (patternSize <= 128) { + if (patternSize <= 128 && isPowerOf2(patternSize)) { clEnqueueMemFillINTEL_fn EnqueueMemFill = nullptr; UR_RETURN_ON_FAILURE( cl_ext::getExtFuncFromContext(