diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 57b332ce65b93..523d2a9d1a8be 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -52,6 +52,7 @@ constexpr const char* OpenVINO_CPU = "OpenVINO_CPU"; constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; constexpr const char* OpenVINO_RT = "OpenVINO_RT"; constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU"; +constexpr const char* QNN_HTP_SHARED = "QnnHtpShared"; constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; constexpr const char* WEBNN_TENSOR = "WebNN_Tensor"; @@ -81,6 +82,10 @@ class IAllocator { */ virtual void* Alloc(size_t size) = 0; + /** + * Free memory at p. + * If p is nullptr, do nothing. + */ virtual void Free(void* p) = 0; // Reserve() is an interface exposed for an implementation of IAllocator diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index 6f658ab65be20..adade482f6a17 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -25,6 +25,7 @@ struct OrtDevice { static const MemoryType CUDA_PINNED = 1; static const MemoryType HIP_PINNED = 2; static const MemoryType CANN_PINNED = 3; + static const MemoryType QNN_HTP_SHARED = 4; }; constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_) diff --git a/include/onnxruntime/core/framework/ortmemoryinfo.h b/include/onnxruntime/core/framework/ortmemoryinfo.h index 7af5554e25c0b..d060c6546ae27 100644 --- a/include/onnxruntime/core/framework/ortmemoryinfo.h +++ b/include/onnxruntime/core/framework/ortmemoryinfo.h @@ -6,6 +6,7 @@ #include #include "core/common/hash_combine.h" +#include "core/framework/ortdevice.h" struct OrtMemoryInfo { OrtMemoryInfo() = default; // to allow default construction of Tensor diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index f3e9758766d00..0a57999246b06 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2130,10 +2130,10 @@ struct KernelContext { explicit KernelContext(OrtKernelContext* context); size_t GetInputCount() const; size_t GetOutputCount() const; - // If input is optional and is not present, the method returns en empty ConstValue + // If input is optional and is not present, the method returns an empty ConstValue // which can be compared to nullptr. ConstValue GetInput(size_t index) const; - // If outout is optional and is not present, the method returns en empty UnownedValue + // If outout is optional and is not present, the method returns an empty UnownedValue // which can be compared to nullptr. 
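  // e.g.: UnownedValue out = kernel_context.GetOutput(i, dims, dim_count);
  //       if (out == nullptr) { /* optional output not present */ }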
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 26b98b0a04d24..02dbb3e518783 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -155,11 +155,18 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { *out = new OrtMemoryInfo( - onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), + onnxruntime::CUDA_PINNED, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), id1, mem_type1); } else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) { *out = new OrtMemoryInfo( - onnxruntime::HIP_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(id1)), + onnxruntime::HIP_PINNED, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(id1)), + id1, mem_type1); + } else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) { + *out = new OrtMemoryInfo( + onnxruntime::QNN_HTP_SHARED, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, static_cast(id1)), id1, mem_type1); } else { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported."); diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 0ac2271ba09f1..55907b64afa63 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -100,7 +100,7 @@ SessionState::SessionState(Graph& graph, for (auto& ep : execution_providers_) { auto allocators = ep->CreatePreferredAllocators(); for (auto& alloc : allocators) { - allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key + allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key } } } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 3af646c3ce13a..430bc04bc29f3 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -7,20 +7,22 @@ #include #include #include "QnnOpDef.h" -#include "HTP/QnnHtpPerfInfrastructure.h" -#include "HTP/QnnHtpSystemContext.h" #include "CPU/QnnCpuCommon.h" // TODO: not exist for Windows yet // #include "GPU/QnnGpuCommon.h" #include "DSP/QnnDspCommon.h" #include "HTP/QnnHtpCommon.h" #include "HTP/QnnHtpContext.h" +#include "HTP/QnnHtpPerfInfrastructure.h" +#include "HTP/QnnHtpSystemContext.h" #include "Saver/QnnSaver.h" #include #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" +#include "core/providers/qnn/qnn_allocator.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" +#include "core/providers/qnn/builder/qnn_utils.h" #ifdef _WIN32 #include @@ -550,10 +552,11 @@ Status QnnBackendManager::CreateContext() { device_handle_, context_configs, &context); - contexts_.push_back(context); ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. 
Error: ", QnnErrorHandleToString(result)); + ORT_RETURN_IF_ERROR(AddQnnContext(context)); + context_created_ = true; return Status::OK(); } @@ -563,6 +566,9 @@ Status QnnBackendManager::ReleaseContext() { return Status::OK(); } + // release context mem handles + context_mem_handles_.clear(); + bool failed = false; for (auto context : contexts_) { Qnn_ErrorHandle_t result = qnn_interface_.contextFree(context, nullptr); @@ -771,7 +777,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t &context, profile_backend_handle_); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt); - contexts_.push_back(context); + ORT_RETURN_IF_ERROR(AddQnnContext(context)); if (1 == graph_count) { // in case the EPContext node is generated from script // the graph name from the context binary may not match the EPContext node name @@ -1413,12 +1419,7 @@ const char* QnnBackendManager::QnnProfileErrorToString(QnnProfile_Error_t error) } const char* QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) { - // From QNN SDK: The memory is statically owned and should not be freed by the caller. - const char* error_msg = nullptr; - if (QNN_SUCCESS == qnn_interface_.errorGetMessage(error, &error_msg)) { - return error_msg; - } - return "Unknown"; + return utils::GetQnnErrorMessage(qnn_interface_, error); } const std::string QnnBackendManager::ExtractQnnScalarValue(const Qnn_Scalar_t& scalar) { @@ -1651,5 +1652,63 @@ void* QnnBackendManager::LibFunction(void* handle, const char* symbol, std::stri #endif } +Status QnnBackendManager::AddQnnContext(Qnn_ContextHandle_t context) { + ORT_RETURN_IF(logger_ == nullptr, "logger_ should be set."); + + auto mem_handle_manager = std::make_shared(GetQnnInterface(), context, *logger_); + const bool inserted = context_mem_handles_.try_emplace(context, std::move(mem_handle_manager)).second; + ORT_RETURN_IF_NOT(inserted, "QNN context was already added: ", context); + + contexts_.push_back(context); + + return Status::OK(); +} + +Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address, + const Qnn_Tensor_t& qnn_tensor, + Qnn_MemHandle_t& mem_handle) { + const auto context_mem_handles_it = context_mem_handles_.find(context); + ORT_RETURN_IF_NOT(context_mem_handles_it != context_mem_handles_.end(), "QNN context not found: ", context); + + auto& context_mem_handle_manager = context_mem_handles_it->second; + bool did_register{}; + ORT_RETURN_IF_ERROR(context_mem_handle_manager->GetOrRegister(shared_memory_address, qnn_tensor, + mem_handle, did_register)); + + if (did_register) { + HtpSharedMemoryAllocator::AllocationCleanUpFn allocation_clean_up = + [&logger = *logger_, + weak_backend_manager = weak_from_this(), + weak_context_mem_handle_manager = std::weak_ptr{context_mem_handle_manager}]( + void* shared_memory_address) { + // get QnnBackendManager shared_ptr to ensure that qnn_interface is still valid + auto backend_manager = weak_backend_manager.lock(); + if (!backend_manager) { + return; + } + + auto context_mem_handle_manager = weak_context_mem_handle_manager.lock(); + if (!context_mem_handle_manager) { + return; + } + + // TODO should also ensure that the QNN context handle is still valid. + // This *should* be true as long as the QNN contexts are not freed from anywhere other than + // ~QnnBackendManager(). If we are able to lock weak_backend_manager, we haven't gotten to the dtor yet. 
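+ // (Note: ReleaseContext() clears context_mem_handles_ before freeing any context, so if this
+ //  handle manager can still be locked, its associated QNN context should not have been freed yet.)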
+ + auto unregister_status = context_mem_handle_manager->Unregister(shared_memory_address); + if (!unregister_status.IsOK()) { + LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: " + << shared_memory_address << ", error: " << unregister_status.ErrorMessage(); + } + }; + + ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::AddAllocationCleanUp(shared_memory_address, + std::move(allocation_clean_up))); + } + + return Status::OK(); +} + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index b145f2a2cd724..cddeffd21f32e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -24,6 +24,7 @@ #include "core/common/status.h" #include "core/common/logging/logging.h" #include "core/common/path_string.h" +#include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" #include "core/providers/qnn/builder/qnn_def.h" namespace onnxruntime { @@ -31,7 +32,7 @@ namespace qnn { class QnnModel; -class QnnBackendManager { +class QnnBackendManager : public std::enable_shared_from_this { public: QnnBackendManager(std::string&& backend_path, ProfilingLevel profiling_level_etw, @@ -170,6 +171,10 @@ class QnnBackendManager { uint64_t buffer_length, uint64_t& max_spill_fill_buffer_size); + Status GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address, + const Qnn_Tensor_t& qnn_tensor, + Qnn_MemHandle_t& mem_handle); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -240,6 +245,9 @@ class QnnBackendManager { const char* eventIdentifier); #endif + Status AddQnnContext(Qnn_ContextHandle_t context); + Status ReleaseQnnContextMemHandles(); + private: const std::string backend_path_; std::mutex logger_mutex_; @@ -253,6 +261,9 @@ class QnnBackendManager { Qnn_LogHandle_t log_handle_ = nullptr; Qnn_DeviceHandle_t device_handle_ = nullptr; std::vector contexts_; + // Note: Using shared_ptr so that we can refer to it with a weak_ptr from a + // HtpSharedMemoryAllocator allocation cleanup callback. + std::unordered_map> context_mem_handles_; ProfilingLevel profiling_level_etw_; ProfilingLevel profiling_level_; ProfilingLevel profiling_level_merge_; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc new file mode 100644 index 0000000000000..73d433942b575 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
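The backend manager is made to derive from std::enable_shared_from_this so that the allocation clean-up callback above can capture only weak_ptrs and quietly become a no-op once the EP and its QnnBackendManager are destroyed. A minimal, self-contained sketch of that idiom follows; the names (Manager, MakeCleanUp) are illustrative and not part of this change.

#include <functional>
#include <iostream>
#include <memory>

struct Manager : std::enable_shared_from_this<Manager> {
  // Hand out a callback that stays safe to invoke even after the Manager has been destroyed.
  std::function<void()> MakeCleanUp() {
    return [weak_self = weak_from_this()] {
      if (auto self = weak_self.lock()) {
        std::cout << "manager alive: unregister resources\n";
      } else {
        std::cout << "manager gone: nothing to do\n";
      }
    };
  }
};

int main() {
  std::function<void()> clean_up;
  {
    auto manager = std::make_shared<Manager>();  // must be owned by a shared_ptr for weak_from_this()
    clean_up = manager->MakeCleanUp();
    clean_up();  // "manager alive: unregister resources"
  }
  clean_up();  // Manager destroyed; the callback degrades to a no-op
}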
+ +#include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" + +#include "HTP/QnnHtpMem.h" + +#include "core/common/common.h" +#include "core/providers/qnn/builder/qnn_def.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/qnn_allocator.h" + +namespace onnxruntime::qnn { + +QnnContextMemHandleManager::QnnContextMemHandleManager(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ContextHandle_t context, + const logging::Logger& logger) + : qnn_interface_{qnn_interface}, + context_{context}, + logger_{logger} { +} + +QnnContextMemHandleManager::~QnnContextMemHandleManager() { + Clear(); +} + +Status QnnContextMemHandleManager::GetOrRegister(void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, + Qnn_MemHandle_t& qnn_mem_handle, bool& did_register) { + const auto qnn_tensor_rank = GetQnnTensorRank(qnn_tensor); + auto* const qnn_tensor_dims = GetQnnTensorDims(qnn_tensor); + const auto qnn_tensor_data_type = GetQnnTensorDataType(qnn_tensor); + + const size_t qnn_tensor_data_size = + utils::GetQnnTensorDataSize(gsl::span{qnn_tensor_dims, size_t{qnn_tensor_rank}}, qnn_tensor_data_type); + + { + std::scoped_lock g{mem_handles_mutex_}; + + // find existing mem handle + if (const auto mem_handles_it = mem_handles_.find(shared_memory_address); + mem_handles_it != mem_handles_.end()) { + const auto& mem_handle_record = mem_handles_it->second; + + // check that actual tensor size is less than or equal to registered tensor size + ORT_RETURN_IF_NOT(qnn_tensor_data_size <= mem_handle_record.registered_tensor_data_size, + "Actual tensor data size (", qnn_tensor_data_size, + ") is larger than registered tensor data size (", mem_handle_record.registered_tensor_data_size, + ")."); + + qnn_mem_handle = mem_handle_record.mem_handle.get(); + did_register = false; + return Status::OK(); + } + + // register a new mem handle + HtpSharedMemoryAllocator::SharedMemoryInfo shared_memory_info{}; + ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(shared_memory_address, + shared_memory_info)); + + Qnn_MemDescriptor_t mem_descriptor{}; + mem_descriptor.memShape.dimSize = qnn_tensor_dims; + mem_descriptor.memShape.numDim = qnn_tensor_rank; + mem_descriptor.memShape.shapeConfig = nullptr; + mem_descriptor.dataType = qnn_tensor_data_type; + mem_descriptor.memType = QNN_MEM_TYPE_CUSTOM; + + QnnMemHtp_Descriptor_t htp_mem_descriptor{}; + htp_mem_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER; + htp_mem_descriptor.size = shared_memory_info.total_size; + htp_mem_descriptor.sharedBufferConfig.fd = shared_memory_info.fd; + htp_mem_descriptor.sharedBufferConfig.offset = shared_memory_info.offset; + + mem_descriptor.customInfo = &htp_mem_descriptor; + + LOGS(logger_, VERBOSE) << "Registering QNN mem handle for context: " << context_ + << ", shared memory (address: " << shared_memory_address + << ", offset: " << shared_memory_info.offset + << ", fd: " << shared_memory_info.fd + << ")"; + + Qnn_MemHandle_t raw_mem_handle{}; + const auto register_result = qnn_interface_.memRegister(context_, &mem_descriptor, 1, &raw_mem_handle); + ORT_RETURN_IF_NOT(register_result == QNN_SUCCESS, + "qnn_interface.memRegister() failed: ", + utils::GetVerboseQnnErrorMessage(qnn_interface_, register_result)); + + LOGS(logger_, VERBOSE) << "Registered QNN mem handle. mem_handle: " << raw_mem_handle; + + const auto unregister_mem_handle = [this](Qnn_MemHandle_t raw_mem_handle) { + LOGS(logger_, VERBOSE) << "Unregistering QNN mem handle. 
mem_handle: " << raw_mem_handle; + + const auto unregister_result = qnn_interface_.memDeRegister(&raw_mem_handle, 1); + if (unregister_result != QNN_SUCCESS) { + LOGS(logger_, ERROR) << "qnn_interface.memDeRegister() failed: " + << utils::GetVerboseQnnErrorMessage(qnn_interface_, unregister_result); + } + }; + + UniqueQnnMemHandle mem_handle(raw_mem_handle, unregister_mem_handle); + MemHandleRecord mem_handle_record{qnn_tensor_data_size, std::move(mem_handle)}; + mem_handles_.emplace(shared_memory_address, std::move(mem_handle_record)); + + qnn_mem_handle = raw_mem_handle; + did_register = true; + return Status::OK(); + } +} + +Status QnnContextMemHandleManager::Unregister(void* shared_memory_address) { + std::scoped_lock g{mem_handles_mutex_}; + + auto mem_handles_it = mem_handles_.find(shared_memory_address); + ORT_RETURN_IF_NOT(mem_handles_it != mem_handles_.end(), + "No mem handle found for address (", shared_memory_address, ")."); + + mem_handles_.erase(mem_handles_it); + + return Status::OK(); +} + +void QnnContextMemHandleManager::Clear() { + std::scoped_lock g{mem_handles_mutex_}; + mem_handles_.clear(); +} + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h new file mode 100644 index 0000000000000..acb33d7175061 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "QnnInterface.h" + +#include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" +#include "core/common/status.h" + +namespace onnxruntime::qnn { + +// This class manages QNN mem handles (Qnn_MemHandle_t) associated with a QNN context (Qnn_ContextHandle_t). +// In particular, it handles the registration and deregistration of mem handles. +// The associated QNN context is expected to be in scope for the lifetime of the QnnContextMemHandleManager. 
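The class declared next keeps each registered handle in a std::unique_ptr with a custom deleter (UniqueQnnMemHandle), a standard way to get RAII for opaque C-style handles. A minimal sketch of the idiom with a stand-in handle type (the names here are illustrative, not from the QNN SDK):

#include <cstdio>
#include <memory>
#include <type_traits>

// Stand-in for an opaque handle typedef such as Qnn_MemHandle_t.
struct FakeHandleRec { int id; };
using FakeHandle = FakeHandleRec*;

void ReleaseFakeHandle(FakeHandle h) {  // stand-in for the C API call that frees the handle
  std::printf("releasing handle %d\n", h->id);
  delete h;
}

// unique_ptr manages the *pointee* type, so strip the pointer from the handle typedef first.
static_assert(std::is_pointer_v<FakeHandle>);
using UniqueFakeHandle = std::unique_ptr<std::remove_pointer_t<FakeHandle>, void (*)(FakeHandle)>;

int main() {
  UniqueFakeHandle handle{new FakeHandleRec{42}, &ReleaseFakeHandle};
  // ReleaseFakeHandle() runs automatically when `handle` goes out of scope.
}

The real UniqueQnnMemHandle uses a std::function deleter instead of a plain function pointer so the deleter can capture the QNN interface and logger, but the ownership mechanics are the same.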
+class QnnContextMemHandleManager { + public: + QnnContextMemHandleManager(const QNN_INTERFACE_VER_TYPE& qnn_interface, Qnn_ContextHandle_t qnn_context, + const logging::Logger& logger); + + ~QnnContextMemHandleManager(); + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnContextMemHandleManager); + + Status GetOrRegister(void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, + Qnn_MemHandle_t& qnn_mem_handle, bool& did_register); + + Status Unregister(void* shared_memory_address); + + void Clear(); + + private: + const QNN_INTERFACE_VER_TYPE& qnn_interface_; + Qnn_ContextHandle_t context_; + const logging::Logger& logger_; + + // assume Qnn_MemHandle_t is a pointer and able to be wrapped with std::unique_ptr + static_assert(std::is_pointer_v); + + using UniqueQnnMemHandle = + std::unique_ptr, std::function>; + + struct MemHandleRecord { + size_t registered_tensor_data_size; + UniqueQnnMemHandle mem_handle; + }; + + // shared memory address -> associated mem handle record + InlinedHashMap mem_handles_; + std::mutex mem_handles_mutex_; // synchronize access to mem_handles_ +}; + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index c0fc079979822..5af7f024716f1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -208,6 +208,22 @@ void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data) ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); } +void SetQnnTensorMemHandle(Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t mem_handle) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.memHandle = mem_handle; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.memHandle = mem_handle; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params) { if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { qnn_tensor.v1.quantizeParams = quantize_params; @@ -350,6 +366,20 @@ const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor) ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); } +Qnn_MemHandle_t GetQnnTensorMemHandle(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.memHandle; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.memHandle; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor) { if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { return qnn_tensor.v1.quantizeParams; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index ffd2dc9b11010..b3b6b392d7857 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -105,6 +105,7 @@ void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size); void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size); void 
SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data); +void SetQnnTensorMemHandle(Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t mem_handle); void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params); bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface, const Qnn_GraphHandle_t& graph, @@ -123,6 +124,7 @@ Qnn_TensorMemType_t GetQnnTensorMemType(const Qnn_Tensor_t& qnn_tensor); uint32_t GetQnnTensorRank(const Qnn_Tensor_t& qnn_tensor); uint32_t* GetQnnTensorDims(const Qnn_Tensor_t& qnn_tensor); const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor); +Qnn_MemHandle_t GetQnnTensorMemHandle(const Qnn_Tensor_t& qnn_tensor); const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor); /** @@ -465,11 +467,13 @@ class QnnOpProperty { class GraphInfo { public: - GraphInfo(const Qnn_GraphHandle_t graph, + GraphInfo(Qnn_GraphHandle_t graph, const std::string& name, + Qnn_ContextHandle_t graph_context, std::vector&& input_tensors, std::vector&& output_tensors) : graph_name_(name), graph_(graph), + graph_context_(graph_context), input_tensors_(std::move(input_tensors)), output_tensors_(std::move(output_tensors)) { } @@ -479,12 +483,15 @@ class GraphInfo { const std::string& Name() const { return graph_name_; } const std::vector& InputTensors() const { return input_tensors_; } const std::vector& OutputTensors() const { return output_tensors_; } - const Qnn_GraphHandle_t& Graph() const { return graph_; } + Qnn_GraphHandle_t Graph() const { return graph_; } + Qnn_ContextHandle_t GraphContext() const { return graph_context_; } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphInfo); private: std::string graph_name_; Qnn_GraphHandle_t graph_; + // QNN context that holds the QNN graph referenced by `graph_` + Qnn_ContextHandle_t graph_context_; std::vector input_tensors_; std::vector output_tensors_; }; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 4f73e4c532ed4..0bbb046605604 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -4,30 +4,30 @@ #include "qnn_model.h" #include +#include #include "QnnOpDef.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_node_group.h" -#include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/qnn_allocator.h" +#include "core/providers/qnn/shared_context.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace qnn { -bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger) { +bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& /* logger */) { bool rt = true; graph_info_ = std::make_unique(model_wrapper.GetQnnGraph(), model_wrapper.GetQnnGraphName(), + model_wrapper.GetQnnGraphContext(), std::move(model_wrapper.GetGraphInputTensorWrappers()), std::move(model_wrapper.GetGraphOutputTensorWrappers())); - if (graph_info_ == nullptr) { - LOGS(logger, ERROR) << "GetGraphInfoFromModel() failed to 
allocate GraphInfo."; - return false; - } return rt; } @@ -185,7 +185,33 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { return Status::OK(); } -Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger) { +static Status BindQnnTensorMemoryToOrtValue(const logging::Logger& logger, + QnnBackendManager& qnn_backend_manager, + const OrtMemoryInfo& ort_value_memory_info, + void* ort_value_data, uint32_t ort_value_data_size, + Qnn_ContextHandle_t qnn_context, + Qnn_Tensor_t& qnn_tensor) { + // either set qnn_tensor memHandle or clientBuf + const bool uses_shared_memory = ort_value_memory_info == HtpSharedMemoryAllocator::AssociatedMemoryInfo(); + + if (!uses_shared_memory) { + LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t clientBuf to ORT tensor memory."; + SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_RAW); + SetQnnTensorClientBuf(qnn_tensor, ort_value_data, ort_value_data_size); + } else { + LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t memHandle to ORT tensor shared memory."; + Qnn_MemHandle_t qnn_mem_handle{}; + ORT_RETURN_IF_ERROR(qnn_backend_manager.GetOrRegisterContextMemHandle(qnn_context, ort_value_data, qnn_tensor, + qnn_mem_handle)); + SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); + SetQnnTensorMemHandle(qnn_tensor, qnn_mem_handle); + } + + return Status::OK(); +} + +Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, + const logging::Logger& logger) { LOGS(logger, VERBOSE) << "QnnModel::ExecuteGraphs"; const size_t num_inputs = context.GetInputCount(); const size_t num_outputs = context.GetOutputCount(); @@ -193,7 +219,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging:: ORT_RETURN_IF_NOT(qnn_output_infos_.size() == num_outputs, "Inconsistent output sizes"); using namespace qnn::utils; - auto TensorDataSize = [&](auto ort_tensor) -> size_t { + auto TensorDataSize = [](auto ort_tensor) -> size_t { auto tensor_type_and_shape = ort_tensor.GetTensorTypeAndShapeInfo(); size_t length = tensor_type_and_shape.GetElementCount(); ONNXTensorElementDataType element_type = tensor_type_and_shape.GetElementType(); @@ -210,13 +236,19 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging:: auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index); auto ort_tensor_size = TensorDataSize(ort_input_tensor); LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size - << "Ort tensor size: " << ort_tensor_size; + << " Ort tensor size: " << ort_tensor_size; ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size, "ORT Tensor data size does not match QNN tensor data size."); qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor()); - SetQnnTensorClientBuf(qnn_inputs.back(), - const_cast(ort_input_tensor.GetTensorData()), qnn_input_info.tensor_byte_size); + + ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( + logger, + *qnn_backend_manager_, + *static_cast(ort_input_tensor.GetTensorMemoryInfo()), + const_cast(ort_input_tensor.GetTensorRawData()), qnn_input_info.tensor_byte_size, + graph_info_->GraphContext(), + qnn_inputs.back())); } std::vector qnn_outputs; @@ -230,24 +262,30 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging:: auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size()); auto ort_tensor_size = TensorDataSize(ort_output_tensor); LOGS(logger, VERBOSE) << "Qnn tensor size: " << 
qnn_output_info.tensor_byte_size - << "Ort tensor size: " << ort_tensor_size; + << " Ort tensor size: " << ort_tensor_size; ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size, "ORT Tensor data size does not match QNN tensor data size"); qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); - SetQnnTensorClientBuf(qnn_outputs.back(), - const_cast(ort_output_tensor.GetTensorData()), qnn_output_info.tensor_byte_size); + + ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( + logger, + *qnn_backend_manager_, + *static_cast(ort_output_tensor.GetTensorMemoryInfo()), + const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, + graph_info_->GraphContext(), + qnn_outputs.back())); } - LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); - auto qnn_interface = qnn_backend_manager_->GetQnnInterface(); - auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR; - { - // Acquire mutex before calling graphExecute and profiling APIs to support calling session.Run() - // from multiple threads. + const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); + + // Acquire mutex before calling QNN APIs to support calling session.Run() from multiple threads. std::lock_guard lock(graph_exec_mutex_); + + LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); + auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); execute_status = qnn_interface.graphExecute(graph_info_->Graph(), qnn_inputs.data(), static_cast(qnn_inputs.size()), @@ -275,20 +313,6 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging:: return Status::OK(); } -Status QnnModel::GetQnnTensorDataLength(const std::vector& dims, - Qnn_DataType_t data_type, - size_t& data_length) const { - ORT_RETURN_IF(dims.empty(), "Tensor dimensions is nullptr"); - - data_length = utils::GetElementSizeByType(data_type); - - for (size_t r = 0; r < dims.size(); r++) { - data_length *= dims[r]; - } - - return Status::OK(); -} - // Setup information for Qnn inputs/outputs used during execution. Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, const std::vector& tensor_wrappers, @@ -298,11 +322,8 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, qnn_tensor_infos.resize(tensor_count); for (auto& tensor_wrapper : tensor_wrappers) { - size_t length = 0; - using namespace qnn::utils; - ORT_RETURN_IF_ERROR(GetQnnTensorDataLength(tensor_wrapper.GetTensorDims(), - tensor_wrapper.GetTensorDataType(), - length)); + const size_t length = utils::GetQnnTensorDataSize(tensor_wrapper.GetTensorDims(), + tensor_wrapper.GetTensorDataType()); const auto& tensor_name = tensor_wrapper.GetName(); auto qnn_index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name); auto ort_index = is_input ? 
GetOrtInputIndex(tensor_name) : qnn_index; @@ -379,9 +400,9 @@ Status QnnModel::DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_Graph graph_info_ = std::make_unique(graph, graph_name, + context, std::move(input_tensor_wrappers), std::move(output_tensor_wrappers)); - ORT_RETURN_IF(graph_info_ == nullptr, "Failed to allocate GraphInfo"); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 2e0935391ca78..2f220e708c50e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -3,15 +3,16 @@ #pragma once +#include #include #include "core/common/status.h" #include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" -#include #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_backend_manager.h" +#include "core/providers/qnn/rpcmem_library.h" #include "core/session/onnxruntime_cxx_api.h" namespace onnxruntime { @@ -43,7 +44,8 @@ class QnnModel { Status SetupQnnInputOutput(const logging::Logger& logger); - Status ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger); + Status ExecuteGraph(const Ort::KernelContext& context, + const logging::Logger& logger); const OnnxTensorInfo* GetOutputInfo(const std::string& name) const { auto it = outputs_info_.find(name); @@ -111,10 +113,6 @@ class QnnModel { const std::unordered_map& node_unit_map) const; bool GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger); - Status GetQnnTensorDataLength(const std::vector& dims, - Qnn_DataType_t data_type, - size_t& data_length) const; - Status SetupTensors(std::vector& tensors, const std::vector& tensor_wrappers, bool is_input = true); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 2c7f3c8b22ddd..c2e3e9516150f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -30,21 +30,23 @@ bool QnnModelWrapper::CreateQnnGraph(const Qnn_ContextHandle_t& context, return false; } if (graph_name.length() == 0) { - LOGS(logger_, ERROR) << "Empty grpah name."; + LOGS(logger_, ERROR) << "Empty graph name."; return false; } - graph_name_ = graph_name; - auto rt = qnn_interface_.graphCreate(context, graph_name_.c_str(), graph_configs, &graph_); + auto rt = qnn_interface_.graphCreate(context, graph_name.c_str(), graph_configs, &graph_); if (rt != QNN_GRAPH_NO_ERROR || graph_ == nullptr) { - rt = qnn_interface_.graphRetrieve(context, graph_name_.c_str(), &graph_); + rt = qnn_interface_.graphRetrieve(context, graph_name.c_str(), &graph_); if (rt != QNN_GRAPH_NO_ERROR || graph_ == nullptr) { LOGS(logger_, ERROR) << "Failed to create Qnn graph: " << graph_name; return false; } } + LOGS(logger_, VERBOSE) << "Created Qnn graph: " << graph_name; + graph_name_ = graph_name; + graph_context_ = context; return true; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index f3e52050e79e0..6e165a5f95afe 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -93,10 +93,12 @@ class QnnModelWrapper { bool ComposeQnnGraph(); - Qnn_GraphHandle_t GetQnnGraph() { return graph_; } + 
Qnn_GraphHandle_t GetQnnGraph() const { return graph_; } std::string GetQnnGraphName() const { return graph_name_; } + Qnn_ContextHandle_t GetQnnGraphContext() const { return graph_context_; } + // Move input tensor wrappers to GraphInfo, QnnModelWrapper end of live std::vector&& GetGraphInputTensorWrappers() { GetGraphInputOutputTensorWrapper(model_input_names_, model_input_tensor_wrappers_); @@ -270,6 +272,8 @@ class QnnModelWrapper { const Qnn_BackendHandle_t& backend_handle_; Qnn_GraphHandle_t graph_ = nullptr; std::string graph_name_ = ""; + // QNN context that holds the QNN graph referenced by `graph_` + Qnn_ContextHandle_t graph_context_ = nullptr; std::vector model_input_names_; std::vector model_output_names_; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 8d2cb5bdb6da0..ad6f48a6d2c48 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -1,15 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/qnn/builder/qnn_utils.h" + #include +#include #include #include #include -#include #include "core/common/common.h" +#include "core/common/safeint.h" #include "core/framework/data_types.h" -#include "qnn_utils.h" #include "core/providers/qnn/builder/qnn_def.h" namespace onnxruntime { @@ -63,6 +65,12 @@ size_t GetElementSizeByType(ONNXTensorElementDataType elem_type) { return pos->second; } +size_t GetQnnTensorDataSize(gsl::span shape, Qnn_DataType_t element_type) { + ORT_ENFORCE(!shape.empty(), "Empty shape not allowed."); // TODO can we just treat empty shape as a scalar? + SafeInt data_length = GetElementSizeByType(element_type); + return std::accumulate(shape.begin(), shape.end(), data_length, std::multiplies<>{}); +} + std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { switch (scalar.dataType) { case QNN_DATATYPE_INT_8: @@ -570,6 +578,27 @@ Status Quantize(const double double_value, return Status::OK(); } +const char* GetQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, Qnn_ErrorHandle_t qnn_error_handle) { + // From QNN SDK: The memory is statically owned and should not be freed by the caller. 
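+ // (By contrast, errorGetVerboseMessage used in GetVerboseQnnErrorMessage below allocates its message,
+ //  so that function releases it with errorFreeVerboseMessage via gsl::finally.)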
+ const char* error_msg = nullptr; + if (qnn_interface.errorGetMessage(qnn_error_handle, &error_msg) == QNN_SUCCESS) { + return error_msg; + } + return "Unknown error."; +} + +std::string GetVerboseQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ErrorHandle_t qnn_error_handle) { + const char* error_msg = nullptr; + if (qnn_interface.errorGetVerboseMessage(qnn_error_handle, &error_msg) == QNN_SUCCESS) { + auto free_error_msg = gsl::finally([&qnn_interface, error_msg] { + qnn_interface.errorFreeVerboseMessage(error_msg); + }); + return error_msg; + } + return "Unknown error."; +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index aa4a27460563f..e07ee64ce33bd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -8,7 +8,11 @@ #include #include +#include + +#include "QnnInterface.h" #include "QnnTypes.h" + #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/node_unit.h" #include "core/util/qmath.h" @@ -22,6 +26,9 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type); size_t GetElementSizeByType(ONNXTensorElementDataType elem_type); +// Gets tensor data size in bytes. +size_t GetQnnTensorDataSize(gsl::span shape, Qnn_DataType_t element_data_type); + // TODO: make these work with Wrappers? std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param); std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor); @@ -104,6 +111,14 @@ Status Quantize(const double double_value, const Qnn_DataType_t qnn_data_type, int& quant_value); +// Gets error message associated with QNN error handle value. +const char* GetQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ErrorHandle_t qnn_error_handle); + +// Gets verbose error message associated with QNN error handle value. +std::string GetVerboseQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ErrorHandle_t qnn_error_handle); + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc new file mode 100644 index 0000000000000..84d67615ff3c9 --- /dev/null +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -0,0 +1,231 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/qnn_allocator.h" + +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/mlas/inc/mlas.h" // for MlasGetPreferredBufferAlignment() + +namespace onnxruntime::qnn { + +namespace { + +struct AllocationHeader { + static constexpr std::array kAllocationHeaderMarker{'o', 'r', 't', 'a', 'l', 'l', 'o', 'c'}; + + // Marker bytes to verify as a sanity check. + std::array marker; + + // Pointer to the allocating allocator instance. + // Note: A critical assumption here is that the allocating allocator is not destroyed before the allocation is freed. 
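+ // The back-pointer is what lets the static helpers GetAllocationSharedMemoryInfo() and
+ // AddAllocationCleanUp() route to the owning allocator instance given only an allocation address,
+ // and lets Free() check that the address really came from this allocator.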
+ HtpSharedMemoryAllocator* allocator_ptr; + + AllocationHeader(HtpSharedMemoryAllocator* allocator_ptr) + : marker{kAllocationHeaderMarker}, + allocator_ptr{allocator_ptr} { + } + + ~AllocationHeader() { + marker.fill('\0'); + allocator_ptr = nullptr; + } +}; + +size_t AllocationAlignment() { + return std::max(alignof(AllocationHeader), MlasGetPreferredBufferAlignment()); +} + +size_t DivRoundUp(size_t a, size_t b) { // TODO is there already a helper function somewhere for this? + return (a + b - 1) / b; +} + +bool IsAligned(const void* address, size_t alignment) { + assert((alignment & alignment - 1) == 0); // alignment must be a power of two + return (reinterpret_cast(address) & (alignment - 1)) == 0; +} + +size_t AllocationOffsetFromStartOfHeader() { + const size_t allocation_alignment = AllocationAlignment(); + const size_t offset = DivRoundUp(sizeof(AllocationHeader), allocation_alignment) * allocation_alignment; + return offset; +} + +std::byte* GetAllocationHeaderAddress(void* allocation_address) { + auto* allocation_header_address = reinterpret_cast(allocation_address) - sizeof(AllocationHeader); + return allocation_header_address; +} + +AllocationHeader& ValidateAllocationAddressAndGetHeader(void* allocation_address) { + const size_t allocation_alignment = AllocationAlignment(); + ORT_ENFORCE(IsAligned(allocation_address, allocation_alignment), + "Allocation address (", allocation_address, ") does not have required alignment (", + allocation_alignment, " bytes)."); + + auto* allocation_header = reinterpret_cast(GetAllocationHeaderAddress(allocation_address)); + ORT_ENFORCE(allocation_header->marker == AllocationHeader::kAllocationHeaderMarker, + "AllocationHeader for allocation address (", allocation_address, + ") does not have the expected marker bytes."); + + return *allocation_header; +} + +std::unique_ptr WrapSharedMemoryWithUniquePtr(void* shared_memory_raw, + const RpcMemApi& rpcmem_api) { + return {shared_memory_raw, rpcmem_api.free}; +} + +} // namespace + +OrtMemoryInfo HtpSharedMemoryAllocator::AssociatedMemoryInfo() { + return OrtMemoryInfo{QNN_HTP_SHARED, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice{OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, /* device_id */ 0}, + /* id */ 0, OrtMemTypeDefault}; +} + +HtpSharedMemoryAllocator::HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, + const logging::Logger* logger) + : IAllocator{AssociatedMemoryInfo()}, + rpcmem_lib_{std::move(rpcmem_lib)}, + logger_(logger != nullptr ? *logger : logging::LoggingManager::DefaultLogger()) { + ORT_ENFORCE(rpcmem_lib_ != nullptr); +} + +void* HtpSharedMemoryAllocator::Alloc(size_t requested_size) { + const size_t allocation_offset = AllocationOffsetFromStartOfHeader(); + const size_t shared_memory_block_size_in_bytes = allocation_offset + requested_size; + + // rpcmem_alloc() has an int size parameter. make sure we don't overflow. 
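+ // (See rpcmem_library.h: AllocFnPtr takes the size as a signed int, following the upstream rpcmem.h.)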
+ constexpr size_t max_size_in_bytes = std::numeric_limits::max(); + ORT_ENFORCE(shared_memory_block_size_in_bytes <= max_size_in_bytes, + "Allocation size (", shared_memory_block_size_in_bytes, ") is larger than maximum allowed (", + max_size_in_bytes, ")."); + + // allocate shared memory + void* shared_memory_raw = rpcmem_lib_->Api().alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, + static_cast(shared_memory_block_size_in_bytes)); + ORT_ENFORCE(shared_memory_raw != nullptr, "rpcmem_alloc() failed to allocate and returned nullptr."); + auto shared_memory = WrapSharedMemoryWithUniquePtr(shared_memory_raw, rpcmem_lib_->Api()); + + const size_t allocation_alignment = AllocationAlignment(); + ORT_ENFORCE(IsAligned(shared_memory_raw, allocation_alignment), + "Shared memory address (", shared_memory_raw, ") does not have required alignment (", + allocation_alignment, " bytes)."); + + // get shared memory fd + const auto shared_memory_fd = rpcmem_lib_->Api().to_fd(shared_memory.get()); + ORT_ENFORCE(shared_memory_fd != -1, "rpcmem_to_fd() returned invalid file descriptor."); + + std::byte* allocation_address = reinterpret_cast(shared_memory_raw) + allocation_offset; + + // store allocation record + { + SharedMemoryInfo shared_memory_info{}; + shared_memory_info.fd = shared_memory_fd; + shared_memory_info.offset = allocation_offset; + shared_memory_info.total_size = shared_memory_block_size_in_bytes; + + AllocationRecord allocation_record{}; + allocation_record.shared_memory_info = std::move(shared_memory_info); + + std::scoped_lock g{allocations_mutex_}; + const bool inserted = allocations_.emplace(allocation_address, std::move(allocation_record)).second; + ORT_ENFORCE(inserted, "Allocation record already exists for address (", allocation_address, ")."); + } + + // initialize header + { + std::byte* allocation_header_address = GetAllocationHeaderAddress(allocation_address); + new (allocation_header_address) AllocationHeader(this); + } + + shared_memory.release(); + return allocation_address; +} + +void HtpSharedMemoryAllocator::Free(void* allocation_address) { + if (allocation_address == nullptr) { + return; + } + + auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); + ORT_ENFORCE(allocation_header.allocator_ptr == this, + "AllocationHeader points to a different allocator (", allocation_header.allocator_ptr, + ") than this one (", this, ")."); + + const auto allocation_node = [this, allocation_address]() { + std::scoped_lock g{allocations_mutex_}; + return allocations_.extract(allocation_address); + }(); + + ORT_ENFORCE(!allocation_node.empty(), "Failed to get allocation info for address (", allocation_address, ")."); + + // At this point, we have a valid allocation to free. + // Avoid throwing exceptions as this may be running from a destructor. 
+ try { + // take ownership of shared memory and free at end of scope + auto shared_memory = WrapSharedMemoryWithUniquePtr(allocation_address, rpcmem_lib_->Api()); + + // destroy header + allocation_header.~AllocationHeader(); + + // clean up allocation record + const auto& allocation_record = allocation_node.mapped(); + for (auto& clean_up_fn : allocation_record.clean_up_fns) { + // attempt to run each clean_up_fn even if exceptions are thrown + try { + clean_up_fn(allocation_address); + } catch (const std::exception& e) { + LOGS(logger_, ERROR) << "Caught exception while running clean up callback for address (" << allocation_address + << "): " << e.what(); + } + } + } catch (const std::exception& e) { + LOGS(logger_, ERROR) << "Caught exception while freeing address (" << allocation_address << "): " << e.what(); + } +} + +Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(void* allocation_address, + SharedMemoryInfo& allocation_info) { + auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); + return allocation_header.allocator_ptr->GetAllocationSharedMemoryInfoForThisAllocator(allocation_address, + allocation_info); +} + +Status HtpSharedMemoryAllocator::AddAllocationCleanUp(void* allocation_address, + AllocationCleanUpFn&& allocation_clean_up) { + auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); + return allocation_header.allocator_ptr->AddAllocationCleanUpForThisAllocator(allocation_address, + std::move(allocation_clean_up)); +} + +Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, + SharedMemoryInfo& allocation_info) { + std::scoped_lock g{allocations_mutex_}; + const auto allocation_it = allocations_.find(allocation_address); + ORT_RETURN_IF(allocation_it == allocations_.end(), + "Failed to get allocation info for address (", allocation_address, ")."); + + allocation_info = allocation_it->second.shared_memory_info; + return Status::OK(); +} + +Status HtpSharedMemoryAllocator::AddAllocationCleanUpForThisAllocator(void* allocation_address, + AllocationCleanUpFn&& allocation_clean_up) { + ORT_RETURN_IF(allocation_clean_up == nullptr, "allocation_clean_up should not be empty."); + + std::scoped_lock g{allocations_mutex_}; + const auto allocation_it = allocations_.find(allocation_address); + ORT_RETURN_IF(allocation_it == allocations_.end(), + "Failed to get allocation info for address (", allocation_address, ")."); + + auto& clean_up_fns = allocation_it->second.clean_up_fns; + clean_up_fns.emplace_back(std::move(allocation_clean_up)); + return Status::OK(); +} + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h new file mode 100644 index 0000000000000..5b854a70fc00f --- /dev/null +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" +#include "core/common/status.h" +#include "core/framework/allocator.h" +#include "core/providers/qnn/rpcmem_library.h" + +namespace onnxruntime::qnn { + +class HtpSharedMemoryAllocator : public IAllocator { + public: + // Gets the OrtMemoryInfo value that is associated with this allocator type. 
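+ // (Currently: name "QnnHtpShared", OrtAllocatorType::OrtDeviceAllocator, and a CPU OrtDevice with
+ //  MemType::QNN_HTP_SHARED, matching the QNN_HTP_SHARED branch added to OrtApis::CreateMemoryInfo.)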
+ static OrtMemoryInfo AssociatedMemoryInfo(); + + HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, + const logging::Logger* logger = nullptr); + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(HtpSharedMemoryAllocator); + + // IAllocator overrides + + void* Alloc(size_t size) override; + void Free(void* p) override; + // void GetStats(AllocatorStats* stats) override; // TODO override + + struct SharedMemoryInfo { + int fd; + uint64_t offset; + uint64_t total_size; + }; + + // Get an allocation's shared memory info. + // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been + // freed. + static Status GetAllocationSharedMemoryInfo(void* allocation_address, + SharedMemoryInfo& allocation_info); + + using AllocationCleanUpFn = std::function; + + // Add allocation clean up callback to call when the allocation is freed. + // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been + // freed. + // `allocation_clean_up` is the clean up callback. The associated allocator takes ownership of the callback. + static Status AddAllocationCleanUp(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up); + + private: + Status GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, + SharedMemoryInfo& allocation_info); + + Status AddAllocationCleanUpForThisAllocator(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up); + + struct AllocationRecord { + SharedMemoryInfo shared_memory_info; + InlinedVector clean_up_fns; + }; + + // allocation address -> corresponding allocation record + InlinedHashMap allocations_; + std::mutex allocations_mutex_; // synchronize access to allocations_ + + std::shared_ptr rpcmem_lib_; + + const logging::Logger& logger_; +}; + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 27e195dea73d2..cfa0555e8149d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -5,24 +5,27 @@ #include #include + #include "core/framework/compute_capability.h" -#include "core/graph/graph_viewer.h" -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/onnxruntime_run_options_config_keys.h" -#include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" +#include "core/framework/run_options.h" +#include "core/graph/graph_viewer.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/partitioning_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_def.h" -#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" -#include "core/framework/run_options.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group.h" +#include "core/providers/qnn/qnn_allocator.h" +#include "core/providers/qnn/rpcmem_library.h" +#include "core/providers/qnn/shared_context.h" +#include 
"core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #ifdef _WIN32 #include @@ -390,7 +393,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio << "handles the graph I/O quantization/dequantization."; } - qnn_backend_manager_ = std::make_unique( + static const std::string QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED = "enable_htp_shared_memory_allocator"; + if (ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map)) { + // Initialize rpcmem_library_. + // This is necessary for HtpSharedMemoryAllocator to function and also indicates that the allocator is available. + rpcmem_library_ = std::make_shared(); + } + + qnn_backend_manager_ = std::make_shared( std::move(backend_path), profiling_level_etw, profiling_level, @@ -1176,4 +1186,25 @@ Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::R return Status::OK(); } + +std::vector QNNExecutionProvider::CreatePreferredAllocators() { + std::vector allocators{}; + + if (IsHtpSharedMemoryAllocatorAvailable()) { + LOGS_DEFAULT(INFO) << "Creating HtpSharedMemoryAllocator."; + + AllocatorFactory rpcmem_allocator_factory = [this](OrtDevice::DeviceId) { + return std::make_unique(rpcmem_library_); + }; + + AllocatorCreationInfo rpcmem_allocator_creation_info{rpcmem_allocator_factory, + /* device_id */ 0, + /* use_arena */ false}; + + allocators.emplace_back(CreateAllocator(rpcmem_allocator_creation_info)); + } + + return allocators; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index a0577e8fd87f2..317b34e66a6e4 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -7,13 +7,15 @@ #include "core/framework/session_options.h" #include "core/framework/model_metadef_id_generator.h" #include "core/graph/model.h" -#include #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" +#include "core/providers/qnn/rpcmem_library.h" #include "HTP/QnnHtpGraph.h" +#include #include #include +#include #include #ifdef _WIN32 #include "core/platform/windows/logging/etw_sink.h" @@ -23,67 +25,6 @@ namespace onnxruntime { void RunOnUnload(std::function function); -class SharedContext { - public: - static SharedContext& GetInstance() { - static SharedContext instance_; - return instance_; - } - - bool HasSharedQnnModels() { - const std::lock_guard lock(mtx_); - return !shared_qnn_models_.empty(); - } - - bool HasQnnModel(const std::string& model_name) { - auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); - return it != shared_qnn_models_.end(); - } - - std::unique_ptr GetSharedQnnModel(const std::string& model_name) { - const std::lock_guard lock(mtx_); - auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); - if (it == shared_qnn_models_.end()) { - return nullptr; - } - auto qnn_model = std::move(*it); - shared_qnn_models_.erase(it); - return qnn_model; - } - - bool SetSharedQnnModel(std::vector>&& shared_qnn_models, - std::string& 
duplicate_graph_names) { - const std::lock_guard lock(mtx_); - bool graph_exist = false; - for (auto& shared_qnn_model : shared_qnn_models) { - auto& model_name = shared_qnn_model->Name(); - auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); - if (it == shared_qnn_models_.end()) { - shared_qnn_models_.push_back(std::move(shared_qnn_model)); - } else { - duplicate_graph_names.append(model_name + " "); - graph_exist = true; - } - } - - return graph_exist; - } - - private: - SharedContext() = default; - ~SharedContext() = default; - SharedContext(const SharedContext&) = delete; - SharedContext& operator=(const SharedContext&) = delete; - - std::vector> shared_qnn_models_; - // Producer sessions can be in parallel - // Consumer sessions have to be after producer sessions initialized - std::mutex mtx_; -}; - // Logical device representation. class QNNExecutionProvider : public IExecutionProvider { public: @@ -113,6 +54,8 @@ class QNNExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + std::vector CreatePreferredAllocators() override; + private: std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, @@ -132,9 +75,13 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::ProfilingLevel GetProfilingLevelFromETWLevel(unsigned char level); + bool IsHtpSharedMemoryAllocatorAvailable() const { return rpcmem_library_ != nullptr; } + private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; - std::unique_ptr qnn_backend_manager_; + // Note: Using shared_ptr so that we can refer to it with a weak_ptr from a + // HtpSharedMemoryAllocator allocation cleanup callback. + std::shared_ptr qnn_backend_manager_; std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; @@ -155,6 +102,10 @@ class QNNExecutionProvider : public IExecutionProvider { #endif qnn::ModelSettings model_settings_ = {}; + // Whether this is set depends on a session option enabling it and if the RPCMEM dynamic library is available. + // This is potentially shared with HtpSharedMemoryAllocator which may be returned by CreatePreferredAllocators(). + std::shared_ptr rpcmem_library_ = nullptr; + class PerThreadContext final { public: PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.cc b/onnxruntime/core/providers/qnn/rpcmem_library.cc new file mode 100644 index 0000000000000..77a340ddfcea1 --- /dev/null +++ b/onnxruntime/core/providers/qnn/rpcmem_library.cc @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#include "core/providers/qnn/rpcmem_library.h" + +#include "core/common/logging/logging.h" +#include "core/platform/env.h" + +namespace onnxruntime::qnn { + +namespace { + +const PathChar* GetRpcMemSharedLibraryPath() { +#if defined(_WIN32) + return ORT_TSTR("libcdsprpc.dll"); +#else + return ORT_TSTR("libcdsprpc.so"); +#endif +} + +DynamicLibraryHandle LoadDynamicLibrary(const PathString& path, bool global_symbols) { + // Custom deleter to unload the shared library. Avoid throwing from it because it may run in dtor. 
+ const auto unload_library = [](void* library_handle) { + if (library_handle == nullptr) { + return; + } + + const auto& env = Env::Default(); + const auto unload_status = env.UnloadDynamicLibrary(library_handle); + + if (!unload_status.IsOK()) { + LOGS_DEFAULT(WARNING) << "Failed to unload shared library. Error: " << unload_status.ErrorMessage(); + } + }; + + const auto& env = Env::Default(); + void* library_handle = nullptr; + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(path, global_symbols, &library_handle)); + + return DynamicLibraryHandle{library_handle, unload_library}; +} + +RpcMemApi CreateApi(void* library_handle) { + RpcMemApi api{}; + + const auto& env = Env::Default(); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_alloc", (void**)&api.alloc)); + + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_free", (void**)&api.free)); + + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_to_fd", (void**)&api.to_fd)); + + return api; +} + +} // namespace + +RpcMemLibrary::RpcMemLibrary() + : library_handle_(LoadDynamicLibrary(GetRpcMemSharedLibraryPath(), /* global_symbols */ false)), + api_{CreateApi(library_handle_.get())} { +} + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.h b/onnxruntime/core/providers/qnn/rpcmem_library.h new file mode 100644 index 0000000000000..d5697ff298e79 --- /dev/null +++ b/onnxruntime/core/providers/qnn/rpcmem_library.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#pragma once + +#include +#include + +#include "core/common/common.h" + +namespace onnxruntime::qnn { + +using DynamicLibraryHandle = std::unique_ptr; + +// This namespace contains constants and typedefs corresponding to functions from rpcmem.h. +// https://github.com/quic/fastrpc/blob/v0.1.1/inc/rpcmem.h +namespace rpcmem { + +constexpr uint32_t RPCMEM_DEFAULT_FLAGS = 1; + +constexpr int RPCMEM_HEAP_ID_SYSTEM = 25; + +/** + * Allocate a zero-copy buffer for size upto 2 GB with the FastRPC framework. + * Buffers larger than 2 GB must be allocated with rpcmem_alloc2 + * @param[in] heapid Heap ID to use for memory allocation. + * @param[in] flags ION flags to use for memory allocation. + * @param[in] size Buffer size to allocate. + * @return Pointer to the buffer on success; NULL on failure. + */ +using AllocFnPtr = void* (*)(int heapid, uint32_t flags, int size); + +/** + * Free a buffer and ignore invalid buffers. + */ +using FreeFnPtr = void (*)(void* po); + +/** + * Return an associated file descriptor. + * @param[in] po Data pointer for an RPCMEM-allocated buffer. + * @return Buffer file descriptor. + */ +using ToFdFnPtr = int (*)(void* po); + +} // namespace rpcmem + +// RPCMEM API function pointers. +struct RpcMemApi { + rpcmem::AllocFnPtr alloc; + rpcmem::FreeFnPtr free; + rpcmem::ToFdFnPtr to_fd; +}; + +// Loads and provides access to the RPCMEM API functions from a dynamically loaded library. 
+class RpcMemLibrary { + public: + RpcMemLibrary(); + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RpcMemLibrary); + + const RpcMemApi& Api() const { return api_; } + + private: + DynamicLibraryHandle library_handle_; + RpcMemApi api_; +}; + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h new file mode 100644 index 0000000000000..fdd3e411e0b7e --- /dev/null +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#include +#include +#include + +#include "core/common/common.h" +#include "core/providers/qnn/builder/qnn_model.h" + +#pragma once + +namespace onnxruntime { + +class SharedContext { + public: + static SharedContext& GetInstance() { + static SharedContext instance_; + return instance_; + } + + bool HasSharedQnnModels() { + const std::lock_guard lock(mtx_); + return !shared_qnn_models_.empty(); + } + + bool HasQnnModel(const std::string& model_name) { + auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + return it != shared_qnn_models_.end(); + } + + std::unique_ptr GetSharedQnnModel(const std::string& model_name) { + const std::lock_guard lock(mtx_); + auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + if (it == shared_qnn_models_.end()) { + return nullptr; + } + auto qnn_model = std::move(*it); + shared_qnn_models_.erase(it); + return qnn_model; + } + + bool SetSharedQnnModel(std::vector>&& shared_qnn_models, + std::string& duplicate_graph_names) { + const std::lock_guard lock(mtx_); + bool graph_exist = false; + for (auto& shared_qnn_model : shared_qnn_models) { + auto& model_name = shared_qnn_model->Name(); + auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + if (it == shared_qnn_models_.end()) { + shared_qnn_models_.push_back(std::move(shared_qnn_model)); + } else { + duplicate_graph_names.append(model_name + " "); + graph_exist = true; + } + } + + return graph_exist; + } + + private: + SharedContext() = default; + ~SharedContext() = default; + SharedContext(const SharedContext&) = delete; + SharedContext& operator=(const SharedContext&) = delete; + + std::vector> shared_qnn_models_; + // Producer sessions can be in parallel + // Consumer sessions have to be after producer sessions initialized + std::mutex mtx_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/IOBinding.h b/onnxruntime/core/session/IOBinding.h index 1f1b3b8073f96..d5a1e273369a1 100644 --- a/onnxruntime/core/session/IOBinding.h +++ b/onnxruntime/core/session/IOBinding.h @@ -51,7 +51,7 @@ class IOBinding { /** * If the BindInput calls are async this function acts as a barrier to ensure all inputs are fully copied - * before you call the Run() method. There is no point calling Run() if you're inputs are not ready at the + * before you call the Run() method. There is no point calling Run() if your inputs are not ready at the * desired location. * This is a blocking call and is a wrapper over IExecutionProvider::Sync(). * Call InferenceSession::Run() only after calling this method or else you'll end up wasting cycles inside Run(). 
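For context (illustrative only, not part of this change): a minimal sketch of how the RpcMemLibrary wrapper introduced above might be exercised, assuming libcdsprpc is present on the device. The function name and the 4 KiB buffer size are placeholders.

#include "core/providers/qnn/rpcmem_library.h"

// Hypothetical smoke test; RpcMemLibrary's constructor loads libcdsprpc and throws if it is unavailable.
void RpcMemSmokeTest() {
  onnxruntime::qnn::RpcMemLibrary rpcmem_lib{};
  const onnxruntime::qnn::RpcMemApi& api = rpcmem_lib.Api();

  // Allocate a zero-copy buffer from the system heap (placeholder size).
  void* buffer = api.alloc(onnxruntime::qnn::rpcmem::RPCMEM_HEAP_ID_SYSTEM,
                           onnxruntime::qnn::rpcmem::RPCMEM_DEFAULT_FLAGS,
                           4096);
  if (buffer != nullptr) {
    const int fd = api.to_fd(buffer);  // file descriptor that can be registered with QNN for shared-memory I/O
    static_cast<void>(fd);
    api.free(buffer);  // per the rpcmem.h docs, free ignores invalid buffers
  }
}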
diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index ac5ea75453558..bebf6e98ff3fa 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -2,12 +2,17 @@ // Licensed under the MIT License. #include "allocator_adapters.h" +#include "core/framework/error_code_helper.h" #include "core/session/inference_session.h" #include "core/session/ort_env.h" #include "core/session/ort_apis.h" -#include "core/framework/error_code_helper.h" namespace onnxruntime { + +namespace { +constexpr uint32_t kOrtAllocatorReserveMinVersion = 18; +} // namespace + OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxruntime::AllocatorPtr&& i_allocator) : i_allocator_(std::move(i_allocator)) { OrtAllocator::version = ORT_API_VERSION; @@ -17,7 +22,7 @@ OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxrunti [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; - if (OrtAllocator::version >= 18) { + if (OrtAllocator::version >= kOrtAllocatorReserveMinVersion) { OrtAllocator::Reserve = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Reserve(size); }; } @@ -51,7 +56,7 @@ void* IAllocatorImplWrappingOrtAllocator::Alloc(size_t size) { } void* IAllocatorImplWrappingOrtAllocator::Reserve(size_t size) { - if (ort_allocator_->version >= 18 && ort_allocator_->Reserve) { + if (ort_allocator_->version >= kOrtAllocatorReserveMinVersion && ort_allocator_->Reserve) { return ort_allocator_->Reserve(ort_allocator_, size); } diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index f641c597acf07..88ad49329f929 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -82,7 +82,11 @@ class ModelTestBuilder { } template - NodeArg* MakeInput(const std::vector& shape, const std::vector& data) { + NodeArg* MakeInput(const std::vector& shape, const std::vector& data, + AllocatorPtr allocator = nullptr) { + if (!allocator) { + allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + } ONNX_NAMESPACE::TypeProto type_proto; type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); @@ -93,7 +97,7 @@ class ModelTestBuilder { } OrtValue input_value; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + CreateMLValue(allocator, shape, data, &input_value); @@ -104,17 +108,19 @@ class ModelTestBuilder { } template - NodeArg* MakeInput(const std::vector& shape, T min, T max) { - return MakeInput(shape, rand_gen_.Uniform(shape, min, max)); + NodeArg* MakeInput(const std::vector& shape, T min, T max, + AllocatorPtr allocator = nullptr) { + return MakeInput(shape, rand_gen_.Uniform(shape, min, max), allocator); } - NodeArg* MakeInputBool(const std::vector& shape) { + NodeArg* MakeInputBool(const std::vector& shape, + AllocatorPtr allocator = nullptr) { std::vector data_uint8 = rand_gen_.Uniform(shape, 0, 1); std::vector data; for (uint8_t x : data_uint8) { data.push_back(x != 0); } - return MakeInput(shape, data); + return MakeInput(shape, data, allocator); } template diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 23c3812ebd025..3a8a76144973f 100644 --- 
a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -101,7 +101,9 @@ namespace perftest { "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" - "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." + "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill fill buffer, used while generating QNN context binary.\n" + "\t [QNN only] [enable_htp_shared_memory_allocator]: Enable the QNN HTP shared memory allocator and use it for inputs and outputs.\n" + "\t Defaults to '0' (disabled).\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index a96028ed3903e..c0e8bfce82631 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -231,7 +231,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } else if (key == "qnn_saver_path") { // no validation } else if (key == "htp_graph_finalization_optimization_mode") { - std::unordered_set supported_htp_graph_final_opt_modes = {"0", "1", "2", "3"}; + std::set supported_htp_graph_final_opt_modes = {"0", "1", "2", "3"}; if (supported_htp_graph_final_opt_modes.find(value) == supported_htp_graph_final_opt_modes.end()) { std::ostringstream str_stream; std::copy(supported_htp_graph_final_opt_modes.begin(), supported_htp_graph_final_opt_modes.end(), @@ -245,7 +245,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high"); } } else if (key == "htp_arch") { - std::unordered_set supported_htp_archs = {"0", "68", "69", "73", "75"}; + std::set supported_htp_archs = {"0", "68", "69", "73", "75"}; if (supported_htp_archs.find(value) == supported_htp_archs.end()) { std::ostringstream str_stream; std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), @@ -253,8 +253,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") { - std::unordered_set supported_options = {"0", "1"}; + } else if (key == "enable_htp_fp16_precision" || + key == "offload_graph_io_quantization" || + key == "enable_htp_spill_fill_buffer" || + key == "enable_htp_shared_memory_allocator") { + std::set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; std::copy(supported_options.begin(), supported_options.end(), @@ -262,6 +265,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for ", key, ". 
select from: ", str); } + + if (key == "enable_htp_shared_memory_allocator" && value == "1") { + // if this option is set, also use the enabled allocator + device_memory_name_ = "QnnHtpShared"; + } } } session_options.AppendExecutionProvider("QNN", provider_options); @@ -836,8 +844,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); }; } else { Ort::MemoryInfo memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); - custom_allocator_ = std::make_unique(session_, memory_info); - allocator_ = *custom_allocator_; + custom_allocator_ = Ort::Allocator(session_, memory_info); + allocator_ = custom_allocator_; // free dimensions are treated as 1 if not overridden transform_fcn = [](int64_t input) { return (input == -1) ? -input : input; }; diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index 7d5e46983ad41..d6580812da8f0 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -39,7 +39,7 @@ class OnnxRuntimeTestSession : public TestSession { std::uniform_int_distribution dist_; std::vector> test_inputs_; OrtAllocator* allocator_ = Ort::AllocatorWithDefaultOptions(); - std::unique_ptr custom_allocator_; + Ort::Allocator custom_allocator_{nullptr}; std::vector outputs_; std::vector output_names_; // The same size with output_names_. diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index e8282dbad9f72..ed21ebbccc923 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -5,11 +5,12 @@ #include #include +#include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU +#include "core/providers/qnn/qnn_allocator.h" +#include "core/session/inference_session.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" -#include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU -#include "core/session/inference_session.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -1098,6 +1099,38 @@ TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { } } +TEST_F(QnnHTPBackendTests, UseHtpSharedMemoryAllocatorForInputs) { +#if !defined(__ANDROID__) && !defined(_WIN32) + // TODO there's probably a better way to check that we are on a Qualcomm device + GTEST_SKIP() << "Test is only supported on a Qualcomm device."; +#endif + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["enable_htp_shared_memory_allocator"] = "1"; + + AllocatorPtr htp_shared_memory_allocator{}; + { + auto allocators = QnnExecutionProviderWithOptions(provider_options)->CreatePreferredAllocators(); + ASSERT_FALSE(allocators.empty()); + auto& allocator = allocators[0]; + ASSERT_EQ(allocator->Info(), qnn::HtpSharedMemoryAllocator::AssociatedMemoryInfo()); + htp_shared_memory_allocator = std::move(allocator); + } + + auto input_defs = {TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f), + TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f)}; + RunQnnModelTest(BuildOpTestCase("Add", input_defs, {}, {}, kOnnxDomain, htp_shared_memory_allocator), + provider_options, + 13, + ExpectedEPNodeAssignment::All, + 
0.008f); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index a8670252ff9e0..676460e108b0e 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -901,10 +901,12 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, * * \param builder Model builder object used to build the model's inputs, outputs, and nodes. * \param input_def Input definition that describes what kind of input to create. + * \param allocator Optional allocator to use to allocate the input ORT value. * \return A pointer to the new input. */ template -inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def) { +inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def, + AllocatorPtr allocator = nullptr) { NodeArg* input = nullptr; const auto& shape = input_def.GetShape(); const bool is_initializer = input_def.IsInitializer(); @@ -915,7 +917,7 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& if (is_initializer) { input = builder.MakeInitializer(shape, raw_data); } else { - input = builder.MakeInput(shape, raw_data); + input = builder.MakeInput(shape, raw_data, allocator); } } else { // Random data const auto& rand_info = input_def.GetRandomDataInfo(); @@ -923,7 +925,7 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& if (is_initializer) { input = builder.MakeInitializer(shape, rand_info.min, rand_info.max); } else { - input = builder.MakeInput(shape, rand_info.min, rand_info.max); + input = builder.MakeInput(shape, rand_info.min, rand_info.max, allocator); } } @@ -931,7 +933,8 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& } template <> -inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def) { +inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def, + AllocatorPtr allocator) { NodeArg* input = nullptr; const auto& shape = input_def.GetShape(); const bool is_initializer = input_def.IsInitializer(); @@ -942,13 +945,13 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef(shape, raw_data); + input = builder.MakeInput(shape, raw_data, allocator); } } else { // Random data if (is_initializer) { input = builder.MakeRandInitializerBool(shape); } else { - input = builder.MakeInputBool(shape); + input = builder.MakeInputBool(shape, allocator); } } @@ -973,6 +976,7 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef @@ -980,18 +984,19 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, const std::vector>& input_defs_1, const std::vector>& input_defs_2, const std::vector& attrs, - const std::string& op_domain = kOnnxDomain) { - return [op_type, input_defs_1, input_defs_2, attrs, op_domain](ModelTestBuilder& builder) { + const std::string& op_domain = kOnnxDomain, + AllocatorPtr input_allocator = nullptr) { + return [op_type, input_defs_1, input_defs_2, attrs, op_domain, input_allocator](ModelTestBuilder& builder) { std::vector op_inputs; op_inputs.reserve(input_defs_1.size() + input_defs_2.size()); for (const auto& input_def : input_defs_1) { - NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* input = MakeTestInput(builder, input_def, input_allocator); op_inputs.push_back(input); } 
for (const auto& input_def : input_defs_2) { - NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* input = MakeTestInput(builder, input_def, input_allocator); op_inputs.push_back(input); } @@ -1012,6 +1017,8 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, * \param input_defs List of input definitions. * \param attrs List of operator attributes. * \param op_domain The operator's domain. Defaults to the ONNX domain (i.e., ""). + * \param use_contrib_qdq Whether to use Q/DQ ops from the MS domain instead of the ONNX domain. + * \param input_allocator Optional allocator to use to allocate input ORT values. * \returns A model building function. */ template @@ -1021,15 +1028,17 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( const std::vector>& non_quant_input_defs, const std::vector& attrs, const std::string& op_domain = kOnnxDomain, - bool use_contrib_qdq = false) { + bool use_contrib_qdq = false, + AllocatorPtr input_allocator = nullptr) { return [op_type, quant_input_defs, non_quant_input_defs, attrs, op_domain, - use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { + use_contrib_qdq, input_allocator]( + ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; op_inputs.reserve(quant_input_defs.size() + non_quant_input_defs.size()); // Create QDQ inputs for (const auto& input_def : quant_input_defs) { - NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* input = MakeTestInput(builder, input_def, input_allocator); QuantParams input_qparams = GetTestInputQuantParams(input_def); NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, use_contrib_qdq); @@ -1038,7 +1047,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( // Create non-QDQ inputs for (const auto& input_def : non_quant_input_defs) { - NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* input = MakeTestInput(builder, input_def, input_allocator); op_inputs.push_back(input); }
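To round out the picture (illustrative only, not part of this change): once enable_htp_shared_memory_allocator is set, an application can look up the allocator the QNN EP registers and allocate I/O tensors from HTP shared memory through the public C++ API, mirroring the perftest wiring above. The model path and tensor shape are placeholders, and the OrtArenaAllocator/OrtMemTypeCPUOutput arguments are simply copied from the perftest change rather than prescribed.

#include <array>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider("QNN", {{"backend_path", "libQnnHtp.so"},
                                                  {"enable_htp_shared_memory_allocator", "1"}});
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);  // placeholder model path

  // Look up the allocator the QNN EP registered under the QnnHtpShared memory info name.
  Ort::MemoryInfo htp_shared_mem_info("QnnHtpShared", OrtArenaAllocator, /*device id*/ 0, OrtMemTypeCPUOutput);
  Ort::Allocator htp_shared_allocator(session, htp_shared_mem_info);

  // Allocate an input tensor backed by HTP shared memory (placeholder shape).
  std::array<int64_t, 4> shape{1, 2, 2, 2};
  Ort::Value input = Ort::Value::CreateTensor<float>(htp_shared_allocator, shape.data(), shape.size());
  return 0;
}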