// Patch summary (source/adapters/hip/enqueue.cpp):
//  * Adds setHipMemAdvise(): returns UR_RESULT_ERROR_INVALID_ENUMERATION when
//    any NON_ATOMIC_MOSTLY or BIAS_* advice flag is set; otherwise issues one
//    hipMemAdvise() call per flag present in URAdviceFlags. Device flags are
//    applied to the Device argument, the *_HOST flags to hipCpuDeviceId.
//  * urEnqueueUSMPrefetch: `flags` is now ignored (std::ignore) instead of
//    being rejected when non-zero, and the managed-memory / isManaged checks
//    move inside the try block, after the optional event is created, so the
//    releaseEvent() helper can record and hand back the event before each
//    early UR_RESULT_ERROR_ADAPTER_SPECIFIC return.
//  * urEnqueueUSMAdvise: replaces the urEnqueueEventsWait() placeholder with
//    hipMemAdvise() calls. UR_USM_ADVICE_FLAG_DEFAULT unsets read-mostly,
//    preferred-location and accessed-by; any other advice is forwarded to
//    setHipMemAdvise().
// NOTE(review): template argument lists look stripped in this copy of the
//   patch (e.g. `std::pair;`, `const_cast(pMem)`, `std::unique_ptr EventPtr`)
//   -- presumably an extraction artifact; confirm against the upstream commit.
// NOTE(review): urEnqueueUSMPrefetch still reads
//   HIP_POINTER_ATTRIBUTE_RANGE_SIZE into an `unsigned int`, while
//   urEnqueueUSMAdvise now uses `size_t` -- confirm which width is expected.
// NOTE(review): urEnqueueUSMAdvise calls EventPtr->start() without
//   UR_CHECK_ERROR, unlike urEnqueueUSMPrefetch -- confirm this is intended.
diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index 7875650b85..68e3e665d2 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -84,6 +84,62 @@ void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock, --ThreadsPerBlock[0]; } } + +ur_result_t setHipMemAdvise(const void *DevPtr, const size_t Size, + ur_usm_advice_flags_t URAdviceFlags, + hipDevice_t Device) { + // Handle unmapped memory advice flags + if (URAdviceFlags & + (UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY | + UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY | + UR_USM_ADVICE_FLAG_BIAS_CACHED | UR_USM_ADVICE_FLAG_BIAS_UNCACHED)) { + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + using ur_to_hip_advice_t = std::pair; + + static constexpr std::array + URToHIPMemAdviseDeviceFlags{ + std::make_pair(UR_USM_ADVICE_FLAG_SET_READ_MOSTLY, + hipMemAdviseSetReadMostly), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY, + hipMemAdviseUnsetReadMostly), + std::make_pair(UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION, + hipMemAdviseSetPreferredLocation), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION, + hipMemAdviseUnsetPreferredLocation), + std::make_pair(UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE, + hipMemAdviseSetAccessedBy), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE, + hipMemAdviseUnsetAccessedBy), + }; + for (auto &FlagPair : URToHIPMemAdviseDeviceFlags) { + if (URAdviceFlags & FlagPair.first) { + UR_CHECK_ERROR(hipMemAdvise(DevPtr, Size, FlagPair.second, Device)); + } + } + + static constexpr std::array URToHIPMemAdviseHostFlags{ + std::make_pair(UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST, + hipMemAdviseSetPreferredLocation), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST, + hipMemAdviseUnsetPreferredLocation), + std::make_pair(UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_HOST, + hipMemAdviseSetAccessedBy), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_HOST, + hipMemAdviseUnsetAccessedBy), + }; + + for 
(auto &FlagPair : URToHIPMemAdviseHostFlags) { + if (URAdviceFlags & FlagPair.first) { + UR_CHECK_ERROR( + hipMemAdvise(DevPtr, Size, FlagPair.second, hipCpuDeviceId)); + } + } + + return UR_RESULT_SUCCESS; +} + } // namespace UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite( @@ -1403,34 +1459,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( ur_queue_handle_t hQueue, const void *pMem, size_t size, ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { + std::ignore = flags; + void *HIPDevicePtr = const_cast(pMem); ur_device_handle_t Device = hQueue->getDevice(); - // If the device does not support managed memory access, we can't set - // mem_advise. - if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) { - setErrorMessage("mem_advise ignored as device does not support " - " managed memory access", - UR_RESULT_SUCCESS); - return UR_RESULT_ERROR_ADAPTER_SPECIFIC; - } - - hipPointerAttribute_t attribs; - // TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated - // memory, as it is neither registered as host memory, nor into the address - // space for the current device, meaning the pMem ptr points to a - // system-allocated memory. This means we may need to check system-alloacted - // memory and handle the failure more gracefully. - UR_CHECK_ERROR(hipPointerGetAttributes(&attribs, pMem)); - // async prefetch requires USM pointer (or hip SVM) to work. - if (!attribs.isManaged) { - setErrorMessage("Prefetch hint ignored as prefetch only works with USM", - UR_RESULT_SUCCESS); - return UR_RESULT_ERROR_ADAPTER_SPECIFIC; - } - - // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5, - // so we can't perform this check for such cases. +// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5, +// so we can't perform this check for such cases. 
#if HIP_VERSION_MAJOR >= 5 unsigned int PointerRangeSize = 0; UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize, @@ -1438,29 +1473,60 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( (hipDeviceptr_t)HIPDevicePtr)); UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE); #endif - // flags is currently unused so fail if set - if (flags != 0) - return UR_RESULT_ERROR_INVALID_VALUE; + ur_result_t Result = UR_RESULT_SUCCESS; - std::unique_ptr EventPtr{nullptr}; try { ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); + + std::unique_ptr EventPtr{nullptr}; + if (phEvent) { EventPtr = std::unique_ptr(ur_event_handle_t_::makeNative( UR_COMMAND_USM_PREFETCH, hQueue, HIPStream)); UR_CHECK_ERROR(EventPtr->start()); } + + // Helper to ensure returning a valid event on early exit. + auto releaseEvent = [&EventPtr, &phEvent]() -> void { + if (phEvent) { + UR_CHECK_ERROR(EventPtr->record()); + *phEvent = EventPtr.release(); + } + }; + + // If the device does not support managed memory access, we can't set + // mem_advise. + if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) { + releaseEvent(); + setErrorMessage("mem_advise ignored as device does not support " + "managed memory access", + UR_RESULT_SUCCESS); + return UR_RESULT_ERROR_ADAPTER_SPECIFIC; + } + + hipPointerAttribute_t attribs; + // TODO: hipPointerGetAttributes will fail if pMem is non-HIP allocated + // memory, as it is neither registered as host memory, nor into the address + // space for the current device, meaning the pMem ptr points to a + // system-allocated memory. This means we may need to check system-allocated + // memory and handle the failure more gracefully. + UR_CHECK_ERROR(hipPointerGetAttributes(&attribs, pMem)); + // async prefetch requires USM pointer (or hip SVM) to work. 
+ if (!attribs.isManaged) { + releaseEvent(); + setErrorMessage("Prefetch hint ignored as prefetch only works with USM", + UR_RESULT_SUCCESS); + return UR_RESULT_ERROR_ADAPTER_SPECIFIC; + } + UR_CHECK_ERROR( hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream)); - if (phEvent) { - UR_CHECK_ERROR(EventPtr->record()); - *phEvent = EventPtr.release(); - } + releaseEvent(); } catch (ur_result_t Err) { Result = Err; } @@ -1468,22 +1534,109 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( return Result; } +/// USM: memadvise API to govern behavior of automatic migration mechanisms UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, - ur_usm_advice_flags_t, ur_event_handle_t *phEvent) { + ur_usm_advice_flags_t advice, ur_event_handle_t *phEvent) { + UR_ASSERT(pMem && size > 0, UR_RESULT_ERROR_INVALID_VALUE); void *HIPDevicePtr = const_cast(pMem); -// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5, -// so we can't perform this check for such cases. + ur_device_handle_t Device = hQueue->getDevice(); + #if HIP_VERSION_MAJOR >= 5 - unsigned int PointerRangeSize = 0; - UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize, - HIP_POINTER_ATTRIBUTE_RANGE_SIZE, - (hipDeviceptr_t)HIPDevicePtr)); + // NOTE: The hipPointerGetAttribute API is marked as beta, meaning, while this + // is feature complete, it is still open to changes and outstanding issues. 
+ size_t PointerRangeSize = 0; + UR_CHECK_ERROR(hipPointerGetAttribute( + &PointerRangeSize, HIP_POINTER_ATTRIBUTE_RANGE_SIZE, + static_cast(HIPDevicePtr))); UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE); #endif - // TODO implement a mapping to hipMemAdvise once the expected behaviour - // of urEnqueueUSMAdvise is detailed in the USM extension - return urEnqueueEventsWait(hQueue, 0, nullptr, phEvent); + + ur_result_t Result = UR_RESULT_SUCCESS; + + try { + ScopedContext Active(Device); + std::unique_ptr EventPtr{nullptr}; + + if (phEvent) { + EventPtr = + std::unique_ptr(ur_event_handle_t_::makeNative( + UR_COMMAND_USM_ADVISE, hQueue, hQueue->getNextTransferStream())); + EventPtr->start(); + } + + // Helper to ensure returning a valid event on early exit. + auto releaseEvent = [&EventPtr, &phEvent]() -> void { + if (phEvent) { + UR_CHECK_ERROR(EventPtr->record()); + *phEvent = EventPtr.release(); + } + }; + + // If the device does not support managed memory access, we can't set + // mem_advise. + if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) { + releaseEvent(); + setErrorMessage("mem_advise ignored as device does not support " + "managed memory access", + UR_RESULT_SUCCESS); + return UR_RESULT_ERROR_ADAPTER_SPECIFIC; + } + + // Passing MEM_ADVICE_SET/MEM_ADVICE_CLEAR_PREFERRED_LOCATION to + // hipMemAdvise on a GPU device requires the GPU device to report a non-zero + // value for hipDeviceAttributeConcurrentManagedAccess. Therefore, ignore + // the mem advice if concurrent managed memory access is not available. 
+ if (advice & (UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION | + UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION | + UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE | + UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE | + UR_USM_ADVICE_FLAG_DEFAULT)) { + if (!getAttribute(Device, hipDeviceAttributeConcurrentManagedAccess)) { + releaseEvent(); + setErrorMessage("mem_advise ignored as device does not support " + "concurrent managed access", + UR_RESULT_SUCCESS); + return UR_RESULT_ERROR_ADAPTER_SPECIFIC; + } + + // TODO: If pMem points to valid system-allocated pageable memory, we + // should check that the device also has the + // hipDeviceAttributePageableMemoryAccess property, so that a valid + // read-only copy can be created on the device. This also applies for + // UR_USM_MEM_ADVICE_SET/MEM_ADVICE_CLEAR_READ_MOSTLY. + } + + const auto DeviceID = Device->get(); + if (advice & UR_USM_ADVICE_FLAG_DEFAULT) { + UR_CHECK_ERROR( + hipMemAdvise(pMem, size, hipMemAdviseUnsetReadMostly, DeviceID)); + UR_CHECK_ERROR(hipMemAdvise( + pMem, size, hipMemAdviseUnsetPreferredLocation, DeviceID)); + UR_CHECK_ERROR( + hipMemAdvise(pMem, size, hipMemAdviseUnsetAccessedBy, DeviceID)); + } else { + Result = setHipMemAdvise(HIPDevicePtr, size, advice, DeviceID); + // UR_RESULT_ERROR_INVALID_ENUMERATION is returned when using valid but + // currently unmapped advice arguments not supported by this platform. + // Therefore, warn the user instead of throwing and aborting the runtime. + if (Result == UR_RESULT_ERROR_INVALID_ENUMERATION) { + releaseEvent(); + setErrorMessage("mem_advise is ignored as the advice argument is not " + "supported by this device", + UR_RESULT_SUCCESS); + return UR_RESULT_ERROR_ADAPTER_SPECIFIC; + } + } + + releaseEvent(); + } catch (ur_result_t err) { + Result = err; + } catch (...) { + Result = UR_RESULT_ERROR_UNKNOWN; + } + + return Result; } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(