From d2a1b7a35398d6612a4d71272a6f784a913c855a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 27 Aug 2024 12:18:52 -0700 Subject: [PATCH] Introduce custom external data loader (#21634) ### Description This PR introduces support for a custom external data loader. An EP can register a custom external data loader to override the default behavior, making it possible to upload initializers directly to GPU. ### Motivation and Context - In ONNX Runtime Web, WebAssembly uses 32-bit as pointer type (`sizeof(size_t)==4`), which means there is a 4GB hard limit on the maximum memory. As the ONNX models get larger, this becomes a blocker for supporting medium-sized language models. - ORT runs out of memory because the current code always loads data into CPU memory, including the .onnx file (protobuf) and external data file(s). However, if using a GPU EP, the big data does not need to be kept on CPU because the only thing that ORT does is to load the data into memory, upload to GPU and then release it. - Some platforms have offered developers ways to upload data directly to GPU. For example, webgpu allows uploading from any ArrayBuffer (it can be a side buffer, not counted toward the 4GB) to GPU directly. This helps to reduce the CPU memory usage significantly. ### Design Class `ExternalDataLoader` and `ExternalDataLoaderManager` are introduced. They are similar to `DataTransfer` and `DataTransferManager`. `InferenceSession` owns the manager object, and `SessionState` keeps a reference to it. Added a new method `GetExternalDataLoader` in `IExecutionProvider`. An EP can override the method to register an instance of custom external data loader. The key function in an `ExternalDataLoader` class is method `LoadTensor`: ```c++ // the tensor is pre-created using the TensorProto info of the initializer and the MemoryInfo (from allocation plan). 
virtual common::Status LoadTensor(const Env& env, const std::filesystem::path& data_file_path, FileOffsetType data_offset, SafeInt data_length, Tensor& tensor) const; ``` This function can be registered by an EP, going through a few layers and eventually getting into `DeserializeTensorProto()` in the finalizing stage of session initialization. In this step, initializer tensors are created. Behavior is changed to first look for a registered external data loader that can handle the current memory info. If any instance is available, use the loader; otherwise respect the old code path. --- .../core/framework/execution_provider.h | 14 +++ .../core/framework/external_data_loader.cc | 100 ++++++++++++++++++ .../core/framework/external_data_loader.h | 60 +++++++++++ .../framework/external_data_loader_manager.cc | 29 +++++ .../framework/external_data_loader_manager.h | 28 +++++ onnxruntime/core/framework/session_state.cc | 8 +- onnxruntime/core/framework/session_state.h | 6 ++ .../core/framework/session_state_utils.cc | 40 ++++--- .../core/framework/session_state_utils.h | 2 + .../core/framework/tensorprotoutils.cc | 84 ++++++--------- onnxruntime/core/framework/tensorprotoutils.h | 7 ++ onnxruntime/core/providers/js/allocator.cc | 6 +- onnxruntime/core/providers/js/allocator.h | 15 +-- .../core/providers/js/external_data_loader.cc | 42 ++++++++ .../core/providers/js/external_data_loader.h | 26 +++++ .../providers/js/js_execution_provider.cc | 9 +- .../core/providers/js/js_execution_provider.h | 1 + .../providers/shared_library/provider_api.h | 1 + onnxruntime/core/session/inference_session.cc | 13 +++ onnxruntime/core/session/inference_session.h | 9 ++ .../test/framework/allocation_planner_test.cc | 5 +- .../test/framework/execution_frame_test.cc | 15 ++- .../test/framework/session_state_test.cc | 23 +++- onnxruntime/test/providers/memcpy_test.cc | 3 +- onnxruntime/wasm/pre-jsep.js | 4 + 25 files changed, 448 insertions(+), 102 deletions(-) create mode 100644 
onnxruntime/core/framework/external_data_loader.cc create mode 100644 onnxruntime/core/framework/external_data_loader.h create mode 100644 onnxruntime/core/framework/external_data_loader_manager.cc create mode 100644 onnxruntime/core/framework/external_data_loader_manager.h create mode 100644 onnxruntime/core/providers/js/external_data_loader.cc create mode 100644 onnxruntime/core/providers/js/external_data_loader.h diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 49c3d1bdd088a..a5b5d2edde46c 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -11,6 +11,7 @@ #include "core/common/logging/logging.h" #include "core/common/status.h" #include "core/framework/data_transfer.h" +#include "core/framework/external_data_loader.h" #include "core/framework/tensor.h" namespace onnxruntime { @@ -88,6 +89,19 @@ class IExecutionProvider { return nullptr; } + /** + * Returns an external data loader object that implements methods to load data from external sources. + * + * By default, framework will handle external data loading by loading the data into CPU memory and then copying + * it to the target device if required. So in most cases, it's not necessary to override this method. Specifically, + * in WebAssembly build, because the memory is limited and Web platform supports loading data from external sources + * directly into GPU memory, this method is overridden to provide a custom external data loader to avoid the extra + * CPU memory usage. + */ + virtual std::unique_ptr GetExternalDataLoader() const { + return nullptr; + } + /** * Interface for performing kernel lookup within kernel registries. * Abstracts away lower-level details about kernel registries and kernel matching. 
diff --git a/onnxruntime/core/framework/external_data_loader.cc b/onnxruntime/core/framework/external_data_loader.cc new file mode 100644 index 0000000000000..ea6c499829391 --- /dev/null +++ b/onnxruntime/core/framework/external_data_loader.cc @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/external_data_loader.h" +#ifndef SHARED_PROVIDER +#include "core/framework/tensor.h" +#endif +#if defined(__wasm__) +#include +#endif + +namespace onnxruntime { + +common::Status IExternalDataLoader::LoadTensor([[maybe_unused]] const Env& env, + [[maybe_unused]] const std::filesystem::path& data_file_path, + [[maybe_unused]] FileOffsetType data_offset, + [[maybe_unused]] SafeInt data_length, + [[maybe_unused]] Tensor& tensor) const { + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); +} + +#if defined(__wasm__) + +common::Status LoadWebAssemblyExternalData(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + ExternalDataLoadType load_type, + void* tensor_data) { + auto err_code = EM_ASM_INT(({ + // If available, "Module.MountedFiles" is a Map for all preloaded files. + if (typeof Module == 'undefined' || !Module.MountedFiles) { + return 1; // "Module.MountedFiles" is not available. + } + let fileName = UTF8ToString($0 >>> 0); + if (fileName.startsWith('./')) { + fileName = fileName.substring(2); + } + const fileData = Module.MountedFiles.get(fileName); + if (!fileData) { + return 2; // File not found in preloaded files. + } + const offset = $1 >>> 0; + const length = $2 >>> 0; + const dataIdOrBuffer = $3 >>> 0; + const loadType = $4; + + if (offset + length > fileData.byteLength) { + return 3; // Out of bounds. + } + + try { + const data = fileData.subarray(offset, offset + length); + switch (loadType) { + case 0: + // Load external data to CPU memory. 
+ // Copy the file data (fileData,offset,length) into WebAssembly memory + // (HEAPU8,buffer,length). + HEAPU8.set(data, dataIdOrBuffer); + break; + case 1: + // Load external data to GPU. + Module.jsepUploadExternalBuffer(dataIdOrBuffer, data); + break; + default: + return 4; // Unknown error occurred in memory copy. + } + return 0; + } catch { + return 4; + } + }), + data_file_path.c_str(), + static_cast(data_offset), + static_cast(data_length), + tensor_data, + static_cast(load_type)); + const char* err_msg; + switch (err_code) { + case 0: + return Status::OK(); + case 1: + err_msg = "Module.MountedFiles is not available."; + break; + case 2: + err_msg = "File not found in preloaded files."; + break; + case 3: + err_msg = "Out of bounds."; + break; + default: + err_msg = "Unknown error occurred in memory copy."; + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", data_file_path, + "\", error: ", err_msg); +} + +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/external_data_loader.h b/onnxruntime/core/framework/external_data_loader.h new file mode 100644 index 0000000000000..117da7d0a4afa --- /dev/null +++ b/onnxruntime/core/framework/external_data_loader.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/safeint.h" +#include "core/platform/env.h" + +struct OrtMemoryInfo; + +namespace onnxruntime { +#ifndef SHARED_PROVIDER +class Tensor; +#endif +class Stream; + +namespace common { +class Status; +} + +// Data transfer interface. +class IExternalDataLoader { + public: + virtual ~IExternalDataLoader() = default; + + virtual bool CanLoad(const OrtMemoryInfo& target_memory_info) const = 0; + + // Tensor should be already allocated with the correct memory info and size. 
+ virtual common::Status LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const; +}; + +#if defined(__wasm__) + +enum class ExternalDataLoadType { + CPU = 0, +#if defined(USE_JSEP) + WEBGPU_BUFFER = 1, +#endif +}; + +// Entry point for loading external data implementation using inline JavaScript. +common::Status LoadWebAssemblyExternalData(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + ExternalDataLoadType load_type, + void* tensor_data); + +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/external_data_loader_manager.cc b/onnxruntime/core/framework/external_data_loader_manager.cc new file mode 100644 index 0000000000000..91161b1d3dd4c --- /dev/null +++ b/onnxruntime/core/framework/external_data_loader_manager.cc @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/framework/external_data_loader_manager.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +using namespace common; + +Status ExternalDataLoaderManager::RegisterExternalDataLoader(std::unique_ptr external_data_loader) { + if (nullptr == external_data_loader) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "external_data_loader registered is nullptr."); + } + external_data_loaders_.push_back(std::move(external_data_loader)); + return Status::OK(); +} + +const IExternalDataLoader* ExternalDataLoaderManager::GetExternalDataLoader(const OrtMemoryInfo& target_memory_info) const { + for (auto& external_data_loader : external_data_loaders_) { + if (!external_data_loader->CanLoad(target_memory_info)) { + continue; + } + + return external_data_loader.get(); + } + return nullptr; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/external_data_loader_manager.h b/onnxruntime/core/framework/external_data_loader_manager.h new file mode 100644 index 0000000000000..38881405c87ff --- /dev/null +++ b/onnxruntime/core/framework/external_data_loader_manager.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/status.h" +#include "core/common/common.h" +#include "core/framework/external_data_loader.h" + +namespace onnxruntime { + +// The external data loader manager manages all registered external data loaders to allow custom +// external data loading implemented by execution providers. 
+class ExternalDataLoaderManager { + public: + ExternalDataLoaderManager() = default; + + common::Status RegisterExternalDataLoader(std::unique_ptr external_data_loader); + + const IExternalDataLoader* GetExternalDataLoader(const OrtMemoryInfo& target_memory_info) const; + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalDataLoaderManager); + + // It's assumed that external data loaders in this array have no overlap in terms of copying functionality. + std::vector> external_data_loaders_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index ddb0c3356e544..4df0370ac719e 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -66,6 +66,7 @@ SessionState::SessionState(Graph& graph, concurrency::ThreadPool* thread_pool, concurrency::ThreadPool* inter_op_thread_pool, const DataTransferManager& data_transfer_mgr, + const ExternalDataLoaderManager& external_data_loader_mgr, const logging::Logger& logger, profiling::Profiler& profiler, const SessionOptions& sess_options, @@ -78,6 +79,7 @@ SessionState::SessionState(Graph& graph, thread_pool_(thread_pool), inter_op_thread_pool_(inter_op_thread_pool), data_transfer_mgr_(data_transfer_mgr), + external_data_loader_mgr_(external_data_loader_mgr), sess_options_(sess_options), prepacked_weights_container_(prepacked_weights_container) #ifdef ORT_ENABLE_STREAM @@ -1046,7 +1048,7 @@ Status SessionState::CreateSubgraphSessionState() { auto subgraph_session_state = std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, - logger_, profiler_, sess_options_, + external_data_loader_mgr_, logger_, profiler_, sess_options_, prepacked_weights_container_, allocators_); // Pass fused function manager to subgraph @@ -1486,8 +1488,8 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string& GetMutableWeightsBuffers() noexcept { 
return weights_buffers_; } const NodeIndexInfo& GetNodeIndexInfo() const; @@ -513,6 +517,8 @@ class SessionState { const DataTransferManager& data_transfer_mgr_; + const ExternalDataLoaderManager& external_data_loader_mgr_; + const SessionOptions& sess_options_; std::optional node_index_info_; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 72f39245d3cfe..2c74805c57dce 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -99,6 +99,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, + const ExternalDataLoaderManager& external_data_loader_mgr, bool use_device_allocator_for_initializers = false, Tensor* buffered_tensor = nullptr) { if (bool(alloc) == (m != nullptr)) { @@ -114,12 +115,24 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); std::unique_ptr p_tensor; - auto device_type = (alloc != nullptr) ? alloc->Info().device.Type() : m->GetAllocInfo().device.Type(); + auto& memory_info = (alloc != nullptr) ? 
alloc->Info() : m->GetAllocInfo(); + auto device_type = memory_info.device.Type(); if (utils::HasExternalData(tensor_proto)) { - if (device_type == OrtDevice::CPU) { + auto external_data_loader = external_data_loader_mgr.GetExternalDataLoader(memory_info); + if (external_data_loader) { + // if custom external data loader is used, always allocate memory on device - p_tensor + ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); + + ORT_RETURN_IF_ERROR(utils::LoadExtDataToTensorFromTensorProto(env, proto_path, tensor_proto, + *external_data_loader, *p_tensor)); + + auto ml_tensor = DataTypeImpl::GetType(); + ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + return common::Status::OK(); + } else if (device_type == OrtDevice::CPU) { // for external initializer on CPU we will use mmap for large initializers so don't need to allocate memory in advance - p_tensor = std::make_unique(type, TensorShape(), alloc); + p_tensor = std::make_unique(); // NB: The file containing external data for the tensor is mmap'd. If the tensor will be used on CPU we can // utilize the mmap'd buffer directly by calling ExtDataTensorProtoToTensor. If we called @@ -143,10 +156,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // 2. load initializer into CPU memory - p_deserialize_tensor, // we will use mmap so no need to allocate memory on CPU in advance // 3. 
copy tensor from CPU to device - p_deserialize_tensor -> p_tensor - auto allocate_on_device_status = AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc); - if (!allocate_on_device_status.IsOK()) { - return allocate_on_device_status; - } + ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); std::unique_ptr p_deserialize_tensor = std::make_unique(type, TensorShape(), default_cpu_alloc); @@ -161,10 +171,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } else { // for internal initializer, always allocate memory on device - p_tensor - auto allocate_on_device_status = AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc); - if (!allocate_on_device_status.IsOK()) { - return allocate_on_device_status; - } + ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); if (device_type == OrtDevice::CPU) { // deserialize directly to CPU tensor @@ -183,10 +190,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st // 2. deserialize tensor_probo into a preallocated tensor (p_deserialize_tensor) // 3. copy tensor from CPU to device - p_deserialize_tensor -> p_tensor std::unique_ptr p_deserialize_tensor; - auto allocate_on_cpu_status = AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, default_cpu_alloc, p_deserialize_tensor); - if (!allocate_on_cpu_status.IsOK()) { - return allocate_on_cpu_status; - } + ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, default_cpu_alloc, p_deserialize_tensor)); ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor)); // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. 
@@ -262,7 +266,9 @@ common::Status SaveInitializedTensors( const std::vector& initializer_allocation_order, ITensorAllocator& planner, const SaveTensorFunction& save_tensor_func, - const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, + const logging::Logger& logger, + const DataTransferManager& data_transfer_mgr, + const ExternalDataLoaderManager& external_data_loader_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, @@ -394,7 +400,7 @@ common::Status SaveInitializedTensors( } Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, - default_cpu_alloc, ort_value, data_transfer_mgr, + default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, use_device_allocator_for_initializers, p_tensor); if (!st.IsOK()) { std::ostringstream oss; diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index 89f4f2c340068..af27f5caba0f4 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -23,6 +23,7 @@ class SessionState; class GraphViewer; class OrtValueNameIdxMap; class DataTransferManager; +class ExternalDataLoaderManager; class NodeArg; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) class MemoryInfo; @@ -45,6 +46,7 @@ common::Status SaveInitializedTensors( const SaveTensorFunction& save_tensor_func, const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, + const ExternalDataLoaderManager& external_data_loader_mgr, const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 42f491825462c..74c359881a1d7 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ 
b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1022,59 +1022,12 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo ext_data_buf = buffer.release(); ext_data_len = raw_data_safe_len; - // In WebAssembly, try use a simplified preloaded file map in WebAssembly when available. - auto err_code = EM_ASM_INT(({ - // If available, "Module.MountedFiles" is a Map for all preloaded files. - if (typeof Module == 'undefined' || !Module.MountedFiles) { - return 1; // "Module.MountedFiles" is not available. - } - let fileName = UTF8ToString($0 >>> 0); - if (fileName.startsWith('./')) { - fileName = fileName.substring(2); - } - const fileData = Module.MountedFiles.get(fileName); - if (!fileData) { - return 2; // File not found in preloaded files. - } - const offset = $1 >>> 0; - const length = $2 >>> 0; - const buffer = $3 >>> 0; - - if (offset + length > fileData.byteLength) { - return 3; // Out of bounds. - } - - try { - // Copy the file data (fileData,offset,length) into WebAssembly memory - // (HEAPU8,buffer,length). 
- HEAPU8.set(fileData.subarray(offset, offset + length), buffer); - return 0; - } catch { - return 4; - } - }), - external_data_file_path.c_str(), - static_cast(file_offset), - static_cast(raw_data_safe_len), - ext_data_buf); - const char* err_msg; - switch (err_code) { - case 0: - return Status::OK(); - case 1: - err_msg = "Module.MountedFiles is not available."; - break; - case 2: - err_msg = "File not found in preloaded files."; - break; - case 3: - err_msg = "Out of bounds."; - break; - default: - err_msg = "Unknown error occurred in memory copy."; - } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to load external data file \"", external_data_file_path, - "\", error: ", err_msg); + ORT_RETURN_IF_ERROR(LoadWebAssemblyExternalData(env, + external_data_file_path, + file_offset, + ext_data_len, + ExternalDataLoadType::CPU, + ext_data_buf)); #else // The GetFileContent function doesn't report error if the requested data range is invalid. Therefore we need to // manually check file size first. 
@@ -1095,6 +1048,31 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo return Status::OK(); } +Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const IExternalDataLoader& ext_data_loader, + Tensor& tensor) { + ORT_ENFORCE(utils::HasExternalData(tensor_proto)); + std::basic_string tensor_proto_dir; + if (!model_path.empty()) { + ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); + } + std::basic_string external_data_file_path; + FileOffsetType file_offset; + SafeInt raw_data_safe_len = 0; + ORT_RETURN_IF_ERROR( + GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len)); + + ORT_RETURN_IF(file_offset < 0 || raw_data_safe_len != tensor.SizeInBytes(), + "External initializer: ", tensor_proto.name(), " offset: ", file_offset, + " size to read: ", static_cast(raw_data_safe_len), + " does not match the tensor size: ", tensor.SizeInBytes()); + ORT_RETURN_IF(external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag, + "Memory address tag is not supported by custom external data loader."); + + return ext_data_loader.LoadTensor(env, external_data_file_path, file_offset, raw_data_safe_len, tensor); +} + #define CASE_PROTO(X, Y) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ ORT_RETURN_IF_ERROR( \ diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 2af1f080be7ee..227ba0706197e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -14,6 +14,7 @@ #include "core/common/safeint.h" #include "core/framework/endian_utils.h" #include "core/framework/allocator.h" +#include "core/framework/external_data_loader.h" #include "core/framework/ort_value.h" #include "core/framework/mem_buffer.h" #include 
"core/framework/tensor_external_data_info.h" @@ -159,6 +160,12 @@ common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem:: OrtCallback& ext_data_deleter, Tensor* buffered_tensor = nullptr); +// Given a tensor proto with external data obtain a tensor using the specified custom external data loader. +common::Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const IExternalDataLoader& ext_data_loader, + Tensor& tensor); + // Convert the AttributeProto from a Constant node into a TensorProto that can be used as an initializer // If AttributeProto contains a TensorProto, this tensor proto is converted as is including the case when the // the data location is external. i.e. it does not load the external data. diff --git a/onnxruntime/core/providers/js/allocator.cc b/onnxruntime/core/providers/js/allocator.cc index 574c507222a5c..d37346a166b03 100644 --- a/onnxruntime/core/providers/js/allocator.cc +++ b/onnxruntime/core/providers/js/allocator.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { -void* JsCustomAllocator::Alloc(size_t size) { +void* WebGpuAllocator::Alloc(size_t size) { if (size == 0) { return nullptr; } @@ -20,14 +20,14 @@ void* JsCustomAllocator::Alloc(size_t size) { return p; } -void JsCustomAllocator::Free(void* p) { +void WebGpuAllocator::Free(void* p) { if (p != nullptr) { size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p); stats_.bytes_in_use -= size; } } -void JsCustomAllocator::GetStats(AllocatorStats* stats) { +void WebGpuAllocator::GetStats(AllocatorStats* stats) { *stats = stats_; } diff --git a/onnxruntime/core/providers/js/allocator.h b/onnxruntime/core/providers/js/allocator.h index 267015b2ea58d..aafb0bb22da7e 100644 --- a/onnxruntime/core/providers/js/allocator.h +++ b/onnxruntime/core/providers/js/allocator.h @@ -9,20 +9,11 @@ namespace onnxruntime { namespace js { -class 
JsCPUAllocator : public CPUAllocator { +class WebGpuAllocator : public IAllocator { public: - JsCPUAllocator() - : CPUAllocator( - OrtMemoryInfo("JsCPUAllocator", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0), - 0, OrtMemTypeCPU)) {}; -}; - -class JsCustomAllocator : public IAllocator { - public: - JsCustomAllocator() + WebGpuAllocator() : IAllocator( - OrtMemoryInfo("JsCustomAllocator", OrtAllocatorType::OrtDeviceAllocator, + OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), 0, OrtMemTypeDefault)) { } diff --git a/onnxruntime/core/providers/js/external_data_loader.cc b/onnxruntime/core/providers/js/external_data_loader.cc new file mode 100644 index 0000000000000..193b373cf3696 --- /dev/null +++ b/onnxruntime/core/providers/js/external_data_loader.cc @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "external_data_loader.h" + +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace js { + +bool ExternalDataLoader::CanLoad(const OrtMemoryInfo& target_memory_info) const { + return target_memory_info.device.Type() == OrtDevice::CPU +#if defined(USE_JSEP) + || (target_memory_info.device.Type() == OrtDevice::GPU && target_memory_info.name == WEBGPU_BUFFER) +#endif + ; +} + +common::Status ExternalDataLoader::LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const { + ExternalDataLoadType load_type; + if (tensor.Location().device.Type() == OrtDevice::CPU) { + load_type = ExternalDataLoadType::CPU; +#if defined(USE_JSEP) + } else if (tensor.Location().device.Type() == OrtDevice::GPU && + tensor.Location().name == WEBGPU_BUFFER) { + load_type = ExternalDataLoadType::WEBGPU_BUFFER; +#endif + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported tensor location: ", tensor.Location().ToString()); + } + + return LoadWebAssemblyExternalData(env, data_file_path, data_offset, data_length, load_type, tensor.MutableDataRaw()); +} + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/external_data_loader.h b/onnxruntime/core/providers/js/external_data_loader.h new file mode 100644 index 0000000000000..5f35ed62bbcc1 --- /dev/null +++ b/onnxruntime/core/providers/js/external_data_loader.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/framework/external_data_loader.h" + +namespace onnxruntime { +namespace js { + +class ExternalDataLoader : public IExternalDataLoader { + public: + ExternalDataLoader() {}; + ~ExternalDataLoader() {}; + + bool CanLoad(const OrtMemoryInfo& target_memory_info) const override; + + common::Status LoadTensor(const Env& env, + const std::filesystem::path& data_file_path, + FileOffsetType data_offset, + SafeInt data_length, + Tensor& tensor) const override; +}; + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 781083ea8707c..1ff33f6d7b410 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -22,6 +22,7 @@ #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" #include "data_transfer.h" +#include "external_data_loader.h" namespace onnxruntime { @@ -737,9 +738,9 @@ JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, co std::vector JsExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo customAllocatorCreationInfo([&](int) { - return std::make_unique(); + return std::make_unique(); }, - 0, false); // TODO(leca): REVIEW: need JsCPUAllocator? 
+ 0, false); return std::vector{CreateAllocator(customAllocatorCreationInfo)}; } @@ -797,6 +798,10 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer return std::make_unique(); } +std::unique_ptr JsExecutionProvider::GetExternalDataLoader() const { + return std::make_unique(); +} + JsExecutionProvider::~JsExecutionProvider() { } diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index efacf510e75df..966f9c6980212 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -48,6 +48,7 @@ class JsExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; + std::unique_ptr GetExternalDataLoader() const override; DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 0bfa52e7869cc..b84825236a453 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -221,6 +221,7 @@ using NameMLValMap = std::unordered_map; #include "core/providers/cpu/math/einsum_utils/einsum_compute_preprocessor.h" #include "core/providers/cpu/cpu_provider_shared.h" #include "core/framework/data_transfer.h" +#include "core/framework/external_data_loader.h" #include "core/framework/execution_provider.h" #include "provider_interfaces.h" #include "provider_wrappedtypes.h" diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4b4136fd6ebd5..b9e017df5baa3 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -830,6 +830,14 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } } + auto 
p_external_data_loader = p_exec_provider->GetExternalDataLoader(); + if (p_external_data_loader) { + auto st = external_data_loader_mgr_.RegisterExternalDataLoader(std::move(p_external_data_loader)); + if (!st.IsOK()) { + return st; + } + } + p_exec_provider->SetLogger(session_logger_); session_profiler_.AddEpProfilers(p_exec_provider->GetProfiler()); return execution_providers_.Add(provider_type, p_exec_provider); @@ -1731,6 +1739,7 @@ common::Status InferenceSession::Initialize() { GetIntraOpThreadPoolToUse(), GetInterOpThreadPoolToUse(), data_transfer_mgr_, + external_data_loader_mgr_, *session_logger_, session_profiler_, session_options_, @@ -2152,6 +2161,10 @@ const DataTransferManager& InferenceSession::GetDataTransferManager() const { return data_transfer_mgr_; } +const ExternalDataLoaderManager& InferenceSession::GetExternalDataLoaderManager() const { + return external_data_loader_mgr_; +} + common::Status InferenceSession::CheckShapes(const std::string& input_output_name, const TensorShape& input_output_shape, const TensorShape& expected_shape, const char* input_output_moniker) const { const auto shape_size = input_output_shape.NumDimensions(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 9662095bf0ed3..8c22fac4dd0c5 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -18,6 +18,7 @@ #include "core/framework/execution_providers.h" #include "core/framework/framework_common.h" #include "core/framework/iexecutor.h" +#include "core/framework/external_data_loader_manager.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state.h" @@ -454,6 +455,11 @@ class InferenceSession { */ const DataTransferManager& GetDataTransferManager() const; + /* + * Get the GetExternalDataLoaderManager associated with this session + */ + const 
ExternalDataLoaderManager& GetExternalDataLoaderManager() const; + /* * Get all the providers' options this session was initialized with. */ @@ -784,6 +790,9 @@ class InferenceSession { // Data transfer manager. DataTransferManager data_transfer_mgr_; + // External data loader manager. + ExternalDataLoaderManager external_data_loader_mgr_; + // Number of concurrently running executors std::atomic current_num_runs_ = 0; diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 43d3782be3280..0105e90b5a24a 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -160,6 +160,7 @@ class PlannerTest : public ::testing::Test { ExecutionProviders execution_providers_; std::unique_ptr tp_; DataTransferManager dtm_; + ExternalDataLoaderManager edlm_; profiling::Profiler profiler_; std::unique_ptr sess_options_; std::unique_ptr state_; @@ -198,7 +199,7 @@ class PlannerTest : public ::testing::Test { sess_options_->enable_mem_pattern = false; sess_options_->use_deterministic_compute = false; sess_options_->enable_mem_reuse = true; - state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, + state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_, DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_)); } @@ -282,7 +283,7 @@ class PlannerTest : public ::testing::Test { } void CreatePlan(const std::vector& outer_scope_node_args = {}, bool invoke_createPlan_explicityly = true) { - state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, + state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_, DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_)); EXPECT_EQ(graph_.Resolve(), Status::OK()); diff --git a/onnxruntime/test/framework/execution_frame_test.cc 
b/onnxruntime/test/framework/execution_frame_test.cc index b95fd0b726a4e..67a0e7fb05241 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -59,6 +59,7 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -67,7 +68,7 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); node->SetExecutionProviderType(xp_typ); @@ -143,6 +144,7 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) { ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -151,7 +153,7 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); node->SetExecutionProviderType(xp_typ); @@ -215,6 +217,7 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -223,7 +226,7 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState 
state(graph, execution_providers, &tp_, nullptr, dtm, + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); @@ -287,6 +290,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { // 1. prepare input DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -295,7 +299,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); @@ -402,10 +406,11 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) { ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions so; - SessionState state(graph, execution_providers, &tp_, nullptr, dtm, DefaultLoggingManager().DefaultLogger(), + SessionState state(graph, execution_providers, &tp_, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, so); ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index ed698ab920147..b94d24a1b180b 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -61,6 +61,7 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_cpu_execution_provider))); DataTransferManager dtm; 
+ ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -69,7 +70,7 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState s(graph, execution_providers, tp.get(), nullptr, dtm, + SessionState s(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); std::vector inputs; @@ -159,6 +160,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { ASSERT_TRUE(status.IsOK()) << status; DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -167,7 +169,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, + SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); GraphPartitioner partitioner(krm, execution_providers); @@ -239,6 +241,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { ASSERT_TRUE(status.IsOK()) << status; DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -250,7 +253,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { ASSERT_STATUS_OK(sess_options.config_options.AddConfigEntry(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "1")); - SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, + SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); // Partition the graph @@ -300,6 +303,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { 
ASSERT_TRUE(status.IsOK()) << status; DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -308,7 +312,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, + SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); // Partition the graph @@ -545,6 +549,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(cpu_execution_provider))); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; std::unordered_map domain_to_version; @@ -573,6 +578,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -604,6 +610,7 @@ class SessionStateTestSharedInitalizersWithPrePacking : public ::testing::Test { ExecutionProviders execution_providers; std::unordered_map domain_to_version; DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; KernelRegistryManager kernel_registry_manager; std::unique_ptr tp; @@ -661,6 +668,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test1) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -687,6 +695,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test1) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -734,6 +743,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test2) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, 
DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -760,6 +770,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test2) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); @@ -809,6 +820,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options, @@ -840,6 +852,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options, @@ -895,6 +908,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options, @@ -945,6 +959,7 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { tp.get(), nullptr, /*inter_op_thread_pool*/ dtm, + edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options, diff --git a/onnxruntime/test/providers/memcpy_test.cc b/onnxruntime/test/providers/memcpy_test.cc index b0cdb7dc97773..4efa359b4e589 100644 --- a/onnxruntime/test/providers/memcpy_test.cc +++ b/onnxruntime/test/providers/memcpy_test.cc @@ -47,6 +47,7 @@ TEST(MemcpyTest, copy1) { PutAllNodesOnOneProvider(model.MainGraph(), onnxruntime::kCpuExecutionProvider); DataTransferManager dtm; + ExternalDataLoaderManager edlm; profiling::Profiler profiler; SessionOptions sess_options; @@ -55,7 +56,7 @@ TEST(MemcpyTest, copy1) { sess_options.use_deterministic_compute = false; sess_options.enable_mem_reuse = true; - SessionState s(model.MainGraph(), execution_providers, &tp, nullptr, dtm, + SessionState s(model.MainGraph(), execution_providers, &tp, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); ASSERT_STATUS_OK(s.FinalizeSessionState(ORT_TSTR(""), 
kernel_registry_manager)); diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 1cb7c6f5d8250..70ed295887994 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -198,5 +198,9 @@ Module['jsepInit'] = (name, params) => { Module['jsepOnRunStart'] = sessionId => { return backend['onRunStart'](sessionId); }; + + Module.jsepUploadExternalBuffer = (dataId, buffer) => { + backend['upload'](dataId, buffer); + }; } };