Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions source/adapters/opencl/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
return UR_RESULT_SUCCESS;
}

// OpenCL only supports pattern sizes as large as the largest CL type
// (double16/long16 - 128 bytes), anything larger we need to do on the host
// side and copy it into the target allocation.
// OpenCL only supports pattern sizes that are powers of 2 and no larger than
// the largest CL type (double16/long16 - 128 bytes). Patterns that are larger
// or not a power of 2 must be expanded on the host side and then copied into
// the target allocation.
clHostMemAllocINTEL_fn HostMemAlloc = nullptr;
UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache,
Expand All @@ -275,14 +276,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
cl_ext::MemBlockingFreeName, &USMFree));

cl_int ClErr = CL_SUCCESS;
auto HostBuffer = static_cast<uint64_t *>(
HostMemAlloc(CLContext, nullptr, size, 0, &ClErr));
auto HostBuffer =
static_cast<uint8_t *>(HostMemAlloc(CLContext, nullptr, size, 0, &ClErr));
CL_RETURN_ON_FAILURE(ClErr);

auto NumValues = size / sizeof(uint64_t);
auto NumChunks = patternSize / sizeof(uint64_t);
for (size_t i = 0; i < NumValues; i++) {
HostBuffer[i] = static_cast<const uint64_t *>(pPattern)[i % NumChunks];
auto *End = HostBuffer + size;
for (auto *Iter = HostBuffer; Iter < End; Iter += patternSize) {
std::memcpy(Iter, pPattern, patternSize);
}

cl_event CopyEvent = nullptr;
Expand Down