Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 196 additions & 43 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,62 @@ void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock,
--ThreadsPerBlock[0];
}
}

ur_result_t setHipMemAdvise(const void *DevPtr, const size_t Size,
ur_usm_advice_flags_t URAdviceFlags,
hipDevice_t Device) {
// Handle unmapped memory advice flags
if (URAdviceFlags &
(UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY |
UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY |
UR_USM_ADVICE_FLAG_BIAS_CACHED | UR_USM_ADVICE_FLAG_BIAS_UNCACHED)) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

using ur_to_hip_advice_t = std::pair<ur_usm_advice_flags_t, hipMemoryAdvise>;

static constexpr std::array<ur_to_hip_advice_t, 6>
URToHIPMemAdviseDeviceFlags{
std::make_pair(UR_USM_ADVICE_FLAG_SET_READ_MOSTLY,
hipMemAdviseSetReadMostly),
std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY,
hipMemAdviseUnsetReadMostly),
std::make_pair(UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION,
hipMemAdviseSetPreferredLocation),
std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION,
hipMemAdviseUnsetPreferredLocation),
std::make_pair(UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE,
hipMemAdviseSetAccessedBy),
std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE,
hipMemAdviseUnsetAccessedBy),
};
for (auto &FlagPair : URToHIPMemAdviseDeviceFlags) {
if (URAdviceFlags & FlagPair.first) {
UR_CHECK_ERROR(hipMemAdvise(DevPtr, Size, FlagPair.second, Device));
}
}

static constexpr std::array<ur_to_hip_advice_t, 4> URToHIPMemAdviseHostFlags{
std::make_pair(UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST,
hipMemAdviseSetPreferredLocation),
std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST,
hipMemAdviseUnsetPreferredLocation),
std::make_pair(UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_HOST,
hipMemAdviseSetAccessedBy),
std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_HOST,
hipMemAdviseUnsetAccessedBy),
};

for (auto &FlagPair : URToHIPMemAdviseHostFlags) {
if (URAdviceFlags & FlagPair.first) {
UR_CHECK_ERROR(
hipMemAdvise(DevPtr, Size, FlagPair.second, hipCpuDeviceId));
}
}

return UR_RESULT_SUCCESS;
}

} // namespace

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
Expand Down Expand Up @@ -1403,87 +1459,184 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
ur_queue_handle_t hQueue, const void *pMem, size_t size,
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
std::ignore = flags;

void *HIPDevicePtr = const_cast<void *>(pMem);
ur_device_handle_t Device = hQueue->getDevice();

// If the device does not support managed memory access, we can't set
// mem_advise.
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
setErrorMessage("mem_advise ignored as device does not support "
" managed memory access",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

hipPointerAttribute_t attribs;
// TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
// memory, as it is neither registered as host memory, nor into the address
// space for the current device, meaning the pMem ptr points to a
// system-allocated memory. This means we may need to check system-alloacted
// memory and handle the failure more gracefully.
UR_CHECK_ERROR(hipPointerGetAttributes(&attribs, pMem));
// async prefetch requires USM pointer (or hip SVM) to work.
if (!attribs.isManaged) {
setErrorMessage("Prefetch hint ignored as prefetch only works with USM",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
// so we can't perform this check for such cases.
// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
// so we can't perform this check for such cases.
#if HIP_VERSION_MAJOR >= 5
unsigned int PointerRangeSize = 0;
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
(hipDeviceptr_t)HIPDevicePtr));
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
#endif
// flags is currently unused so fail if set
if (flags != 0)
return UR_RESULT_ERROR_INVALID_VALUE;

ur_result_t Result = UR_RESULT_SUCCESS;
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);

std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

if (phEvent) {
EventPtr =
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
UR_COMMAND_USM_PREFETCH, hQueue, HIPStream));
UR_CHECK_ERROR(EventPtr->start());
}

// Helper to ensure returning a valid event on early exit.
auto releaseEvent = [&EventPtr, &phEvent]() -> void {
if (phEvent) {
UR_CHECK_ERROR(EventPtr->record());
*phEvent = EventPtr.release();
}
};

// If the device does not support managed memory access, we can't set
// mem_advise.
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
releaseEvent();
setErrorMessage("mem_advise ignored as device does not support "
"managed memory access",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

hipPointerAttribute_t attribs;
// TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated
// memory, as it is neither registered as host memory, nor into the address
// space for the current device, meaning the pMem ptr points to a
// system-allocated memory. This means we may need to check system-alloacted
// memory and handle the failure more gracefully.
UR_CHECK_ERROR(hipPointerGetAttributes(&attribs, pMem));
// async prefetch requires USM pointer (or hip SVM) to work.
if (!attribs.isManaged) {
releaseEvent();
setErrorMessage("Prefetch hint ignored as prefetch only works with USM",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

UR_CHECK_ERROR(
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
if (phEvent) {
UR_CHECK_ERROR(EventPtr->record());
*phEvent = EventPtr.release();
}
releaseEvent();
} catch (ur_result_t Err) {
Result = Err;
}

return Result;
}

/// USM: memadvise API to govern behavior of automatic migration mechanisms
UR_APIEXPORT ur_result_t UR_APICALL
urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
ur_usm_advice_flags_t, ur_event_handle_t *phEvent) {
ur_usm_advice_flags_t advice, ur_event_handle_t *phEvent) {
UR_ASSERT(pMem && size > 0, UR_RESULT_ERROR_INVALID_VALUE);
void *HIPDevicePtr = const_cast<void *>(pMem);
// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
// so we can't perform this check for such cases.
ur_device_handle_t Device = hQueue->getDevice();

#if HIP_VERSION_MAJOR >= 5
unsigned int PointerRangeSize = 0;
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
(hipDeviceptr_t)HIPDevicePtr));
// NOTE: The hipPointerGetAttribute API is marked as beta, meaning, while this
// is feature complete, it is still open to changes and outstanding issues.
size_t PointerRangeSize = 0;
UR_CHECK_ERROR(hipPointerGetAttribute(
&PointerRangeSize, HIP_POINTER_ATTRIBUTE_RANGE_SIZE,
static_cast<hipDeviceptr_t>(HIPDevicePtr)));
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
#endif
// TODO implement a mapping to hipMemAdvise once the expected behaviour
// of urEnqueueUSMAdvise is detailed in the USM extension
return urEnqueueEventsWait(hQueue, 0, nullptr, phEvent);

ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(Device);
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

if (phEvent) {
EventPtr =
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
UR_COMMAND_USM_ADVISE, hQueue, hQueue->getNextTransferStream()));
EventPtr->start();
}

// Helper to ensure returning a valid event on early exit.
auto releaseEvent = [&EventPtr, &phEvent]() -> void {
if (phEvent) {
UR_CHECK_ERROR(EventPtr->record());
*phEvent = EventPtr.release();
}
};

// If the device does not support managed memory access, we can't set
// mem_advise.
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
releaseEvent();
setErrorMessage("mem_advise ignored as device does not support "
"managed memory access",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

// Passing MEM_ADVICE_SET/MEM_ADVICE_CLEAR_PREFERRED_LOCATION to
// hipMemAdvise on a GPU device requires the GPU device to report a non-zero
// value for hipDeviceAttributeConcurrentManagedAccess. Therefore, ignore
// the mem advice if concurrent managed memory access is not available.
if (advice & (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION |
UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION |
UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE |
UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE |
UR_USM_ADVICE_FLAG_DEFAULT)) {
if (!getAttribute(Device, hipDeviceAttributeConcurrentManagedAccess)) {
releaseEvent();
setErrorMessage("mem_advise ignored as device does not support "
"concurrent managed access",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

// TODO: If pMem points to valid system-allocated pageable memory, we
// should check that the device also has the
// hipDeviceAttributePageableMemoryAccess property, so that a valid
// read-only copy can be created on the device. This also applies for
// UR_USM_MEM_ADVICE_SET/MEM_ADVICE_CLEAR_READ_MOSTLY.
}

const auto DeviceID = Device->get();
if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
UR_CHECK_ERROR(
hipMemAdvise(pMem, size, hipMemAdviseUnsetReadMostly, DeviceID));
UR_CHECK_ERROR(hipMemAdvise(
pMem, size, hipMemAdviseUnsetPreferredLocation, DeviceID));
UR_CHECK_ERROR(
hipMemAdvise(pMem, size, hipMemAdviseUnsetAccessedBy, DeviceID));
} else {
Result = setHipMemAdvise(HIPDevicePtr, size, advice, DeviceID);
// UR_RESULT_ERROR_INVALID_ENUMERATION is returned when using a valid but
// currently unmapped advice arguments as not supported by this platform.
// Therefore, warn the user instead of throwing and aborting the runtime.
if (Result == UR_RESULT_ERROR_INVALID_ENUMERATION) {
releaseEvent();
setErrorMessage("mem_advise is ignored as the advice argument is not "
"supported by this device",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}
}

releaseEvent();
} catch (ur_result_t err) {
Result = err;
} catch (...) {
Result = UR_RESULT_ERROR_UNKNOWN;
}

return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(
Expand Down