diff --git a/source/adapters/hip/memory.cpp b/source/adapters/hip/memory.cpp index 41cb2b94d0..950626ff83 100644 --- a/source/adapters/hip/memory.cpp +++ b/source/adapters/hip/memory.cpp @@ -481,11 +481,91 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate( return Result; } -/// \TODO Not implemented -UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t, - ur_image_info_t, size_t, - void *, size_t *) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +UR_APIEXPORT ur_result_t UR_APICALL urMemImageGetInfo(ur_mem_handle_t hMemory, + ur_image_info_t propName, + size_t propSize, + void *pPropValue, + size_t *pPropSizeRet) { + UR_ASSERT(hMemory->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT); + ScopedContext Active(hMemory->getContext()->getDevice()); + UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); + + try { + + HIP_ARRAY3D_DESCRIPTOR ArrayInfo; + UR_CHECK_ERROR(hipArray3DGetDescriptor( + &ArrayInfo, std::get(hMemory->Mem).Array)); + + const auto hip2urFormat = + [](hipArray_Format HipFormat) -> ur_image_channel_type_t { + switch (HipFormat) { + case HIP_AD_FORMAT_UNSIGNED_INT8: + return UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT8; + case HIP_AD_FORMAT_UNSIGNED_INT16: + return UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT16; + case HIP_AD_FORMAT_UNSIGNED_INT32: + return UR_IMAGE_CHANNEL_TYPE_UNSIGNED_INT32; + case HIP_AD_FORMAT_SIGNED_INT8: + return UR_IMAGE_CHANNEL_TYPE_SIGNED_INT8; + case HIP_AD_FORMAT_SIGNED_INT16: + return UR_IMAGE_CHANNEL_TYPE_SIGNED_INT16; + case HIP_AD_FORMAT_SIGNED_INT32: + return UR_IMAGE_CHANNEL_TYPE_SIGNED_INT32; + case HIP_AD_FORMAT_HALF: + return UR_IMAGE_CHANNEL_TYPE_HALF_FLOAT; + case HIP_AD_FORMAT_FLOAT: + return UR_IMAGE_CHANNEL_TYPE_FLOAT; + + default: + detail::ur::die("Invalid Hip format specified."); + } + }; + + const auto hipFormatToElementSize = + [](hipArray_Format HipFormat) -> size_t { + switch (HipFormat) { + case HIP_AD_FORMAT_UNSIGNED_INT8: + case HIP_AD_FORMAT_SIGNED_INT8: + return 1; + case HIP_AD_FORMAT_UNSIGNED_INT16: + case HIP_AD_FORMAT_SIGNED_INT16: + case HIP_AD_FORMAT_HALF: + return 2; + case HIP_AD_FORMAT_UNSIGNED_INT32: + case HIP_AD_FORMAT_SIGNED_INT32: + case HIP_AD_FORMAT_FLOAT: + return 4; + default: + detail::ur::die("Invalid Hip format specified."); + } + }; + + switch (propName) { + case UR_IMAGE_INFO_FORMAT: + return ReturnValue(ur_image_format_t{UR_IMAGE_CHANNEL_ORDER_RGBA, + hip2urFormat(ArrayInfo.Format)}); + case UR_IMAGE_INFO_WIDTH: + return ReturnValue(ArrayInfo.Width); + case UR_IMAGE_INFO_HEIGHT: + return ReturnValue(ArrayInfo.Height); + case UR_IMAGE_INFO_DEPTH: + return ReturnValue(ArrayInfo.Depth); + case UR_IMAGE_INFO_ELEMENT_SIZE: + return ReturnValue(hipFormatToElementSize(ArrayInfo.Format)); + case UR_IMAGE_INFO_ROW_PITCH: + case UR_IMAGE_INFO_SLICE_PITCH: + return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; + + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + } catch (ur_result_t Err) { + return Err; + } catch (...) { + return UR_RESULT_ERROR_UNKNOWN; + } + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) {