diff --git a/source/adapters/native_cpu/enqueue.cpp b/source/adapters/native_cpu/enqueue.cpp index 835a7febcf..b5d4713e2f 100644 --- a/source/adapters/native_cpu/enqueue.cpp +++ b/source/adapters/native_cpu/enqueue.cpp @@ -511,8 +511,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( UR_ASSERT(ptr, UR_RESULT_ERROR_INVALID_NULL_POINTER); UR_ASSERT(pPattern, UR_RESULT_ERROR_INVALID_NULL_POINTER); - UR_ASSERT(size % patternSize == 0 || patternSize > size, - UR_RESULT_ERROR_INVALID_SIZE); + UR_ASSERT(patternSize != 0, UR_RESULT_ERROR_INVALID_SIZE) + UR_ASSERT(size != 0, UR_RESULT_ERROR_INVALID_SIZE) + UR_ASSERT(patternSize < size, UR_RESULT_ERROR_INVALID_SIZE) + UR_ASSERT(size % patternSize == 0, UR_RESULT_ERROR_INVALID_SIZE) + // TODO: add check for allocation size once the query is supported switch (patternSize) { case 1: @@ -522,7 +525,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( const auto pattern = *static_cast(pPattern); auto *start = reinterpret_cast(ptr); auto *end = - reinterpret_cast(reinterpret_cast(ptr) + size); + reinterpret_cast(reinterpret_cast(ptr) + size); std::fill(start, end, pattern); break; } @@ -530,7 +533,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( const auto pattern = *static_cast(pPattern); auto *start = reinterpret_cast(ptr); auto *end = - reinterpret_cast(reinterpret_cast(ptr) + size); + reinterpret_cast(reinterpret_cast(ptr) + size); std::fill(start, end, pattern); break; } @@ -538,17 +541,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( const auto pattern = *static_cast(pPattern); auto *start = reinterpret_cast(ptr); auto *end = - reinterpret_cast(reinterpret_cast(ptr) + size); + reinterpret_cast(reinterpret_cast(ptr) + size); std::fill(start, end, pattern); break; } - default: - for (unsigned int step{0}; step < size; ++step) { - auto *dest = reinterpret_cast(reinterpret_cast(ptr) + - step * patternSize); + default: { + for (unsigned int step{0}; step < size; step += patternSize) { + auto *dest = + reinterpret_cast(reinterpret_cast(ptr) + step); memcpy(dest, pPattern, patternSize); } } + } return UR_RESULT_SUCCESS; } @@ -583,7 +587,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( std::ignore = phEventWaitList; std::ignore = phEvent; - DIE_NO_IMPLEMENTATION; + // TODO: properly implement USM prefetch + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL @@ -595,7 +600,8 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, std::ignore = advice; std::ignore = phEvent; - DIE_NO_IMPLEMENTATION; + // TODO: properly implement USM advise + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(