Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions unified-runtime/source/adapters/offload/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,14 @@ struct ur_context_handle_t_ : RefCounted {
~ur_context_handle_t_() { urDeviceRelease(Device); }

ur_device_handle_t Device;
std::unordered_map<void *, alloc_info_t> AllocTypeMap;

std::optional<alloc_info_t> getAllocType(const void *UsmPtr) {
for (auto &pair : AllocTypeMap) {
if (UsmPtr >= pair.first &&
reinterpret_cast<uintptr_t>(UsmPtr) <
reinterpret_cast<uintptr_t>(pair.first) + pair.second.Size) {
return pair.second;
}
ol_result_t getAllocType(const void *UsmPtr, ol_alloc_type_t &Type) {
auto Err = olGetMemInfo(UsmPtr, OL_MEM_INFO_TYPE, sizeof(Type), &Type);
if (Err && Err->Code == OL_ERRC_NOT_FOUND) {
// Treat unknown allocations as host
Type = OL_ALLOC_TYPE_HOST;
return OL_SUCCESS;
}
return std::nullopt;
return Err;
}
};
25 changes: 12 additions & 13 deletions unified-runtime/source/adapters/offload/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,19 +440,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
size_t size, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
auto GetDevice = [&](const void *Ptr) {
auto Res = hQueue->UrContext->getAllocType(Ptr);
if (!Res)
return Adapter->HostDevice;
return Res->Type == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice
: hQueue->OffloadDevice;
};

return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, GetDevice(pDst), pSrc,
GetDevice(pSrc), size, blocking, numEventsInWaitList,
phEventWaitList, phEvent);

return UR_RESULT_SUCCESS;
ol_alloc_type_t DstTy;
OL_RETURN_ON_ERR(hQueue->UrContext->getAllocType(pDst, DstTy));
ol_device_handle_t Dst =
DstTy == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice : hQueue->OffloadDevice;

ol_alloc_type_t SrcTy;
OL_RETURN_ON_ERR(hQueue->UrContext->getAllocType(pSrc, SrcTy));
ol_device_handle_t Src =
SrcTy == OL_ALLOC_TYPE_HOST ? Adapter->HostDevice : hQueue->OffloadDevice;

return doMemcpy(UR_COMMAND_USM_MEMCPY, hQueue, pDst, Dst, pSrc, Src, size,
blocking, numEventsInWaitList, phEventWaitList, phEvent);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMAdvise(
Expand Down
89 changes: 70 additions & 19 deletions unified-runtime/source/adapters/offload/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
size_t size, void **ppMem) {
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
OL_ALLOC_TYPE_HOST, size, ppMem));

hContext->AllocTypeMap.insert_or_assign(
*ppMem, alloc_info_t{OL_ALLOC_TYPE_HOST, size});
return UR_RESULT_SUCCESS;
}

Expand All @@ -33,9 +30,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
ur_usm_pool_handle_t, size_t size, void **ppMem) {
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
OL_ALLOC_TYPE_DEVICE, size, ppMem));

hContext->AllocTypeMap.insert_or_assign(
*ppMem, alloc_info_t{OL_ALLOC_TYPE_DEVICE, size});
return UR_RESULT_SUCCESS;
}

Expand All @@ -44,23 +38,80 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
ur_usm_pool_handle_t, size_t size, void **ppMem) {
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
OL_ALLOC_TYPE_MANAGED, size, ppMem));

hContext->AllocTypeMap.insert_or_assign(
*ppMem, alloc_info_t{OL_ALLOC_TYPE_MANAGED, size});
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
void *pMem) {
hContext->AllocTypeMap.erase(pMem);
UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t, void *pMem) {
return offloadResultToUR(olMemFree(pMem));
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo(
[[maybe_unused]] ur_context_handle_t hContext,
[[maybe_unused]] const void *pMem,
[[maybe_unused]] ur_usm_alloc_info_t propName,
[[maybe_unused]] size_t propSize, [[maybe_unused]] void *pPropValue,
[[maybe_unused]] size_t *pPropSizeRet) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
UR_APIEXPORT ur_result_t UR_APICALL
urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
ur_usm_alloc_info_t propName, size_t propSize,
void *pPropValue, size_t *pPropSizeRet) {
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

ol_mem_info_t olInfo;

switch (propName) {
case UR_USM_ALLOC_INFO_TYPE:
olInfo = OL_MEM_INFO_TYPE;
break;
case UR_USM_ALLOC_INFO_BASE_PTR:
olInfo = OL_MEM_INFO_BASE;
break;
case UR_USM_ALLOC_INFO_SIZE:
olInfo = OL_MEM_INFO_SIZE;
break;
case UR_USM_ALLOC_INFO_DEVICE:
// Contexts can only contain one device
return ReturnValue(hContext->Device);
case UR_USM_ALLOC_INFO_POOL:
default:
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
break;
}

if (pPropSizeRet) {
OL_RETURN_ON_ERR(olGetMemInfoSize(pMem, olInfo, pPropSizeRet));
}

if (pPropValue) {
auto Err = olGetMemInfo(pMem, olInfo, propSize, pPropValue);
if (Err && Err->Code == OL_ERRC_NOT_FOUND) {
// If the device didn't allocate this object, return default values
switch (propName) {
case UR_USM_ALLOC_INFO_TYPE:
return ReturnValue(UR_USM_TYPE_UNKNOWN);
case UR_USM_ALLOC_INFO_BASE_PTR:
return ReturnValue(nullptr);
case UR_USM_ALLOC_INFO_SIZE:
return ReturnValue(0);
default:
return UR_RESULT_ERROR_UNKNOWN;
}
}
OL_RETURN_ON_ERR(Err);

if (propName == UR_USM_ALLOC_INFO_TYPE) {
auto *OlType = reinterpret_cast<ol_alloc_type_t *>(pPropValue);
auto *UrType = reinterpret_cast<ur_usm_type_t *>(pPropValue);
switch (*OlType) {
case OL_ALLOC_TYPE_HOST:
*UrType = UR_USM_TYPE_HOST;
break;
case OL_ALLOC_TYPE_DEVICE:
*UrType = UR_USM_TYPE_DEVICE;
break;
case OL_ALLOC_TYPE_MANAGED:
*UrType = UR_USM_TYPE_SHARED;
break;
default:
*UrType = UR_USM_TYPE_UNKNOWN;
break;
}
}
}

return UR_RESULT_SUCCESS;
}