diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 8991bbe6a01ae..800a2b898526c 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -3,35 +3,18 @@ #include "core/framework/allocator.h" #include "core/framework/allocatormgr.h" -#include "core/mlas/inc/mlas.h" +#include "core/framework/utils.h" #include #include namespace onnxruntime { void* CPUAllocator::Alloc(size_t size) { - if (size <= 0) return nullptr; - void* p; - size_t alignment = MlasGetPreferredBufferAlignment(); -#if _MSC_VER - p = _aligned_malloc(size, alignment); - if (p == nullptr) throw std::bad_alloc(); -#elif defined(_LIBCPP_SGX_CONFIG) - p = memalign(alignment, size); - if (p == nullptr) throw std::bad_alloc(); -#else - int ret = posix_memalign(&p, alignment, size); - if (ret != 0) throw std::bad_alloc(); -#endif - return p; + return utils::DefaultAlloc(size); } void CPUAllocator::Free(void* p) { -#if _MSC_VER - _aligned_free(p); -#else - free(p); -#endif + utils::DefaultFree(p); } const OrtAllocatorInfo& CPUAllocator::Info() const { return *allocator_info_; } diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index fdbea44864403..cb126236d15e9 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -16,9 +16,35 @@ #include "core/framework/parallel_executor.h" #include "core/framework/session_state.h" #include "core/framework/sequential_executor.h" +#include "core/mlas/inc/mlas.h" namespace onnxruntime { namespace utils { +void* DefaultAlloc(size_t size) { + if (size <= 0) return nullptr; + void* p; + size_t alignment = MlasGetPreferredBufferAlignment(); +#if _MSC_VER + p = _aligned_malloc(size, alignment); + if (p == nullptr) throw std::bad_alloc(); +#elif defined(_LIBCPP_SGX_CONFIG) + p = memalign(alignment, size); + if (p == nullptr) throw std::bad_alloc(); +#else + int ret = posix_memalign(&p, alignment, size); + if (ret != 0) throw std::bad_alloc(); +#endif + return p; +} + +void DefaultFree(void* p) { +#if _MSC_VER + _aligned_free(p); +#else + free(p); +#endif +} + AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorInfo& allocator_info) { return session_state.GetExecutionProviders().GetAllocator(allocator_info); } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index b096f1ecbaf8b..881762da85a2d 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -25,6 +25,8 @@ class Logger; } namespace utils { +void* DefaultAlloc(size_t size); +void DefaultFree(void* p); AllocatorPtr GetAllocator(const SessionState& session_state, const OrtAllocatorInfo& allocator_info); diff --git a/onnxruntime/core/session/default_cpu_allocator_c_api.cc b/onnxruntime/core/session/default_cpu_allocator_c_api.cc index a4a7e7ea51612..2d7bc9843b4c8 100644 --- a/onnxruntime/core/session/default_cpu_allocator_c_api.cc +++ b/onnxruntime/core/session/default_cpu_allocator_c_api.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include "core/framework/utils.h" #include "core/session/onnxruntime_cxx_api.h" #include @@ -23,10 +24,10 @@ struct OrtDefaultAllocator : OrtAllocatorImpl { ~OrtDefaultAllocator() override { OrtReleaseAllocatorInfo(cpuAllocatorInfo); } void* Alloc(size_t size) { - return ::malloc(size); + return onnxruntime::utils::DefaultAlloc(size); } void Free(void* p) { - return ::free(p); + onnxruntime::utils::DefaultFree(p); } const OrtAllocatorInfo* Info() const { return cpuAllocatorInfo;