
Commit 80daa9b

Reduce the peak memory even with CPU fallback by moving the fallback within the basic_backend.cc scope (#723)
1 parent 13e7792 commit 80daa9b

File tree: 3 files changed (+111, -110 lines)

- onnxruntime/core/providers/openvino/backend_manager.cc
- onnxruntime/core/providers/openvino/backends/basic_backend.cc
- onnxruntime/core/providers/openvino/backends/basic_backend.h

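Why this lowers the peak: before this commit, the NPU-to-CPU fallback lived in backend_manager.cc and retried through BackendFactory::MakeBackend, which forced basic_backend.cc to keep model_proto alive whenever a fallback was still possible, so the ONNX proto and its serialized string could coexist at full size. With the fallback moved inside basic_backend.cc, the constructor serializes once, releases the proto right away (for static-shape graphs), and reuses the retained string for the CPU retry. A minimal sketch of that ordering; FakeModelProto and CompileModel below are made-up stand-ins, not the real ONNX or OpenVINO APIs:

```cpp
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for ModelProto and the OV compile call, used only
// to illustrate the serialize / release / compile / fallback ordering.
struct FakeModelProto {
  std::string SerializeAsString() const { return std::string(1 << 20, 'x'); }
};

std::string CompileModel(const std::string& model, const std::string& device) {
  if (device == "NPU") throw std::runtime_error("intel_npu: compile failed");
  return "compiled-on-" + device;
}

int main() {
  auto model_proto = std::make_unique<FakeModelProto>();

  // Serialize once, then drop the proto immediately: the serialized string
  // alone is enough to retry on CPU, so proto and string never coexist long.
  std::string model = model_proto->SerializeAsString();
  model_proto.reset();

  std::string exe;
  try {
    exe = CompileModel(model, "NPU");
  } catch (const std::exception& ex) {
    std::cerr << ex.what() << " -> falling back to CPU\n";
    exe = CompileModel(model, "CPU");  // reuses the retained string
  }
  std::cout << exe << '\n';
}
```

The point is simply that model_proto.reset() happens before compilation is attempted, so a later CPU retry never needs the proto back.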

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 19 additions & 43 deletions
```diff
@@ -153,52 +153,28 @@ BackendManager::BackendManager(SessionContext& session_context,
                                                       model_stream);
     } catch (const OnnxRuntimeException& ex) {
       std::string exception_str = ex.what();
-      bool eligible_for_cpu_fallback = device_type.find("NPU") != std::string::npos &&
-                                       !session_context_.so_disable_cpu_ep_fallback &&
-                                       !subgraph_context_.is_ep_ctx_graph;
-#if defined(OPENVINO_DISABLE_NPU_FALLBACK)
-      eligible_for_cpu_fallback = false;
-#else
-      if (eligible_for_cpu_fallback) {
-        LOGS_DEFAULT(VERBOSE) << exception_str;
-        LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."
-                              << "Falling back to OV CPU for execution";
-        session_context_.device_type = "CPU";
-        session_context_.precision = "FP32";
-        try {
-          concrete_backend_ = BackendFactory::MakeBackend(model_proto,
-                                                          session_context_,
-                                                          subgraph_context_,
-                                                          shared_context_,
-                                                          model_stream);
-        } catch (std::string const& msg) {
-          ORT_THROW(msg);
-        }
-      }
-#endif
-      if (!eligible_for_cpu_fallback) {
-        if (device_type.find("NPU") != std::string::npos &&
-            exception_str.find("intel_npu") != std::string::npos) {
-          // Handle NPU device related errors
+
+      if (session_context_.device_type.find("NPU") != std::string::npos &&
+          exception_str.find("intel_npu") != std::string::npos) {
+        // Handle NPU device related errors
 #ifndef NDEBUG
-          ORT_THROW(exception_str + "\nModel needs to be recompiled\n");
+        ORT_THROW(exception_str + "\nModel needs to be recompiled\n");
 #else
-          std::string error_message = "UNKNOWN NPU ERROR";
-          std::string error_code = "code 0x0";
-          std::regex error_message_pattern(R"(\bZE_\w*\b)");
-          std::regex error_code_pattern("code 0x[0-9a-fA-F]+");
-          std::smatch matches;
-          if (std::regex_search(exception_str, matches, error_message_pattern)) {
-            error_message = matches[0];
-          }
-          if (std::regex_search(exception_str, matches, error_code_pattern)) {
-            error_code = matches[0];
-          }
-          throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n");
-#endif
-        } else {
-          ORT_THROW(exception_str);
+        std::string error_message = "UNKNOWN NPU ERROR";
+        std::string error_code = "code 0x0";
+        std::regex error_message_pattern(R"(\bZE_\w*\b)");
+        std::regex error_code_pattern("code 0x[0-9a-fA-F]+");
+        std::smatch matches;
+        if (std::regex_search(exception_str, matches, error_message_pattern)) {
+          error_message = matches[0];
         }
+        if (std::regex_search(exception_str, matches, error_code_pattern)) {
+          error_code = matches[0];
+        }
+        throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n");
+#endif
+      } else {
+        ORT_THROW(exception_str);
       }
     }
   }
```
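With the fallback gone from this file, what remains is the NPU error distillation: in release builds (the #else branch), a driver exception is reduced to its Level Zero error name (ZE_*) and hex code before rethrowing. A standalone sketch of just that regex logic, exercised against an invented sample exception string:

```cpp
#include <iostream>
#include <regex>
#include <string>

// Standalone demo of the error-distillation logic kept in backend_manager.cc:
// pull a Level Zero error name (ZE_*) and a hex error code out of an
// exception string. The sample input below is made up for illustration.
int main() {
  std::string exception_str =
      "Exception from intel_npu: ZE_RESULT_ERROR_DEVICE_LOST, code 0x70000001";

  std::string error_message = "UNKNOWN NPU ERROR";  // defaults if no match
  std::string error_code = "code 0x0";
  std::regex error_message_pattern(R"(\bZE_\w*\b)");
  std::regex error_code_pattern("code 0x[0-9a-fA-F]+");
  std::smatch matches;
  if (std::regex_search(exception_str, matches, error_message_pattern)) {
    error_message = matches[0];
  }
  if (std::regex_search(exception_str, matches, error_code_pattern)) {
    error_code = matches[0];
  }
  std::cout << error_message << ", " << error_code << '\n';
  // prints: ZE_RESULT_ERROR_DEVICE_LOST, code 0x70000001
}
```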

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 91 additions & 67 deletions
```diff
@@ -36,42 +36,14 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   if (ValidateSubgraph(const_outputs_map_))
     return;
 
-  // OV Config
+  // Pre-requisite is provider_option "context" must be set
+  auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) ||
+                               (session_context_.OpenVINO_Version.at(0) >= 2024 &&
+                                session_context_.OpenVINO_Version.at(1) > 2));
   ov::AnyMap device_config;
-  PopulateConfigValue(device_config);
-
-  // Enable caching
-  EnableCaching();
-
-  // Setting OpenCL queue throttling for GPU
-  EnableGPUThrottling(device_config);
-
-  // Enable streams; default=1 unless overridden by user configuration
-  EnableStreams();
-
-  // Set the inference_num_threads property of the CPU
-  SetNumThreads(device_config);
-
-  auto npuw_status =
-      std::any_of(device_config.begin(), device_config.end(), [&](const std::pair<std::string, ov::Any>& pair) {
-        return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is<std::string>()) &&
-               (pair.second.as<std::string>() == "YES");
-      });
-
-  if (npuw_status) {
-    LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation";
-  }
-
-  try {
-    // IO_BUFFER is enabled on GPU HW.
-    // Pre-requisite is provider_option "context" must be set
-    auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) ||
-                                 (session_context_.OpenVINO_Version.at(0) >= 2024 &&
-                                  session_context_.OpenVINO_Version.at(1) > 2));
-    bool disable_cpu_fallback = !(hw_target.find("NPU") != std::string::npos &&
-                                  !session_context_.so_disable_cpu_ep_fallback &&
-                                  !subgraph_context_.is_ep_ctx_graph);
-    if (subgraph_context_.is_ep_ctx_graph) {
+  SetOVDeviceConfiguration(device_config);
+  if (subgraph_context_.is_ep_ctx_graph) {
+    try {
       if (subgraph_context_.is_ep_ctx_ovir_encapsulated) {
         // model_file_path will use so_context_file_path if the onnx_model_path_name is not available,
         // especially in case of CreateSessionFormArray() where user must explicitly
@@ -104,41 +76,67 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
                                                   device_config,
                                                   subgraph_context_.subgraph_name);
       }
-      model_stream.reset();  // Delete stream after it is no longer needed
-    } else if (!session_context_.has_external_weights &&
-               !subgraph_context_.has_dynamic_input_shape &&
-               !session_context_.so_context_enable &&
-               session_context_.reshape.empty() &&
-               !enable_causallm &&
-               auto_unified_compile) {
-      // Unified OV compile_model is efficient when ov model caching is enabled
-      // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
-      // Inputs with static dimensions
-      // Not enabled for models with external weights and when ep context is set.
-      const std::string model = model_proto->SerializeAsString();
-      // we have the serialized string, so we can release model proto to lower the peak memory consumption
-      if (disable_cpu_fallback) model_proto.reset();
-      exe_network_ = OVCore::Get()->CompileModel(model,
-                                                 hw_target,
-                                                 device_config,
-                                                 subgraph_context_.subgraph_name);
-    } else {  // For all other types use ov::ov_core read_model() to generate OV IR
-              // followed by ov::ov_core compile_model()
-      std::string model = model_proto->SerializeAsString();
-      // Reset model proto only when cpu fallback is disabled or when the model has dynamic input shapes.
-      // This is to avoid memory peak usage when the model is large.
-      if (!subgraph_context.has_dynamic_input_shape && disable_cpu_fallback) {
-        model_proto.reset();
+      model_stream.reset();
+    } catch (const char* msg) {
+      ORT_THROW(msg);
+    }  // Delete stream after it is no longer needed
+  } else {
+    std::string model = model_proto->SerializeAsString();
+    if (!subgraph_context.has_dynamic_input_shape) {
+      model_proto.reset();
+    }
+    try {
+      // SetOVDeviceConfiguration(device_config);
+      if (!session_context_.has_external_weights &&
+          !subgraph_context_.has_dynamic_input_shape &&
+          !session_context_.so_context_enable &&
+          session_context_.reshape.empty() &&
+          !enable_causallm &&
+          auto_unified_compile) {
+        // Unified OV compile_model is efficient when ov model caching is enabled
+        // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
+        // Inputs with static dimensions
+        // Not enabled for models with external weights and when ep context is set.
+
+        exe_network_ = OVCore::Get()->CompileModel(model,
+                                                   hw_target,
+                                                   device_config,
+                                                   subgraph_context_.subgraph_name);
+      } else {  // For all other types use ov::ov_core read_model() to generate OV IR
+                // followed by ov::ov_core compile_model()
+        auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
+        exe_network_ = OVCore::Get()->CompileModel(
+            ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
+      }
+      LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
+    } catch (const OnnxRuntimeException& ex) {
+      std::string exception_str = ex.what();
+      bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos &&
+                                       !session_context_.so_disable_cpu_ep_fallback &&
+                                       !subgraph_context_.is_ep_ctx_graph;
+#if defined(OPENVINO_DISABLE_NPU_FALLBACK)
+      eligible_for_cpu_fallback = false;
+#endif
+      if (eligible_for_cpu_fallback) {
+        LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."
+                              << "Falling back to OV CPU for execution";
+        session_context_.device_type = "CPU";
+        session_context_.precision = "FP32";
+        device_config.clear();
+        SetOVDeviceConfiguration(device_config);
+        try {
+          // Recreate the model with CPU device type
+          auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
+          exe_network_ = OVCore::Get()->CompileModel(
+              ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
+        } catch (std::string const& msg) {
+          ORT_THROW(msg);
+        }
+      } else {
+        ORT_THROW(ex.what());
       }
-      auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
-      exe_network_ = OVCore::Get()->CompileModel(
-          ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
     }
-    LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
-  } catch (const char* msg) {
-    ORT_THROW(msg);
   }
-
   int num_infer_req = (session_context_.num_of_threads > 0) ? session_context_.num_of_threads : 1;
   std::function<void(OVInferRequestPtr)> initializer = [](OVInferRequestPtr) {};
   auto metadata = shared_context_.shared_weights.metadata;
@@ -385,6 +383,32 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
   device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
 }
 
+void BasicBackend::SetOVDeviceConfiguration(ov::AnyMap& device_config) {
+  PopulateConfigValue(device_config);
+
+  // Enable caching
+  EnableCaching();
+
+  // Setting OpenCL queue throttling for GPU
+  EnableGPUThrottling(device_config);
+
+  // Enable streams; default=1 unless overridden by user configuration
+  EnableStreams();
+
+  // Set the inference_num_threads property of the CPU
+  SetNumThreads(device_config);
+
+  auto npuw_status =
+      std::any_of(device_config.begin(), device_config.end(), [&](const std::pair<std::string, ov::Any>& pair) {
+        return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is<std::string>()) &&
+               (pair.second.as<std::string>() == "YES");
+      });
+
+  if (npuw_status) {
+    LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation";
+  }
+}
+
 void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
                                                       const ov::PartialShape& partial_shape) const {
   // Check if the number of dimensions matches
```
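The new SetOVDeviceConfiguration helper collects the configuration steps that previously ran inline in the constructor, ending with an std::any_of scan that reports whether NPU_USE_NPUW=YES made it into the device config. The sketch below isolates that scan; std::map<std::string, std::string> stands in for ov::AnyMap so it builds without OpenVINO, which collapses ov::Any's is<std::string>()/as<std::string>() checks into a plain string comparison:

```cpp
#include <algorithm>
#include <iostream>
#include <map>
#include <string>

// The NPUW detection from SetOVDeviceConfiguration, with a plain string map
// standing in for ov::AnyMap so the sketch compiles standalone.
int main() {
  std::map<std::string, std::string> device_config{
      {"NPU_USE_NPUW", "YES"},  // sample entry; real values come from provider options
      {"PERFORMANCE_HINT", "LATENCY"},
  };

  // True if any config key mentions NPU_USE_NPUW with the value "YES".
  bool npuw_status = std::any_of(
      device_config.begin(), device_config.end(),
      [](const std::pair<const std::string, std::string>& pair) {
        return pair.first.find("NPU_USE_NPUW") != std::string::npos &&
               pair.second == "YES";
      });

  if (npuw_status) {
    std::cout << "NPUW Enabled during compilation\n";
  }
}
```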

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -146,6 +146,7 @@ class BasicBackend : public IBackend {
   void EnableGPUThrottling(ov::AnyMap& device_config);
   void EnableStreams();
   void SetNumThreads(ov::AnyMap& device_config);
+  void SetOVDeviceConfiguration(ov::AnyMap& device_config);
   void ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
                                           const ov::PartialShape& partial_shape) const;
 
```
