@@ -36,42 +36,14 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   if (ValidateSubgraph(const_outputs_map_))
     return;
 
-  // OV Config
+  // Pre-requisite is provider_option "context" must be set
+  auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) ||
+                               (session_context_.OpenVINO_Version.at(0) >= 2024 &&
+                                session_context_.OpenVINO_Version.at(1) > 2));
   ov::AnyMap device_config;
-  PopulateConfigValue(device_config);
-
-  // Enable caching
-  EnableCaching();
-
-  // Setting OpenCL queue throttling for GPU
-  EnableGPUThrottling(device_config);
-
-  // Enable streams; default=1 unless overridden by user configuration
-  EnableStreams();
-
-  // Set the inference_num_threads property of the CPU
-  SetNumThreads(device_config);
-
-  auto npuw_status =
-      std::any_of(device_config.begin(), device_config.end(), [&](const std::pair<std::string, ov::Any>& pair) {
-        return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is<std::string>()) &&
-               (pair.second.as<std::string>() == "YES");
-      });
-
-  if (npuw_status) {
-    LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation";
-  }
-
-  try {
-    // IO_BUFFER is enabled on GPU HW.
-    // Pre-requisite is provider_option "context" must be set
-    auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) ||
-                                 (session_context_.OpenVINO_Version.at(0) >= 2024 &&
-                                  session_context_.OpenVINO_Version.at(1) > 2));
-    bool disable_cpu_fallback = !(hw_target.find("NPU") != std::string::npos &&
-                                  !session_context_.so_disable_cpu_ep_fallback &&
-                                  !subgraph_context_.is_ep_ctx_graph);
-    if (subgraph_context_.is_ep_ctx_graph) {
+  SetOVDeviceConfiguration(device_config);
+  if (subgraph_context_.is_ep_ctx_graph) {
+    try {
       if (subgraph_context_.is_ep_ctx_ovir_encapsulated) {
         // model_file_path will use so_context_file_path if the onnx_model_path_name is not available,
         // especially in case of CreateSessionFromArray() where user must explicitly
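Note on the `auto_unified_compile` gate hoisted to the top of the constructor in this hunk: the unified single-step compile path is allowed whenever the target is not AUTO, or on AUTO when the OpenVINO release is 2024.3 or newer (per the comment retained later in this diff). A minimal standalone sketch of the same predicate; the free function `UseUnifiedCompile` and its parameters are illustrative, not part of the codebase:

#include <string>
#include <vector>

// Illustrative stand-in: hw_target mirrors the device string and ov_version
// mirrors session_context_.OpenVINO_Version ({major, minor}).
bool UseUnifiedCompile(const std::string& hw_target, const std::vector<int>& ov_version) {
  // Always allowed off AUTO; on AUTO it needs major >= 2024 and minor > 2.
  return (hw_target.find("AUTO") == std::string::npos) ||
         (ov_version.at(0) >= 2024 && ov_version.at(1) > 2);
}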
@@ -104,41 +76,67 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
                                               device_config,
                                               subgraph_context_.subgraph_name);
       }
-      model_stream.reset();  // Delete stream after it is no longer needed
-    } else if (!session_context_.has_external_weights &&
-               !subgraph_context_.has_dynamic_input_shape &&
-               !session_context_.so_context_enable &&
-               session_context_.reshape.empty() &&
-               !enable_causallm &&
-               auto_unified_compile) {
-      // Unified OV compile_model is efficient when ov model caching is enabled
-      // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
-      // Inputs with static dimensions
-      // Not enabled for models with external weights and when ep context is set.
-      const std::string model = model_proto->SerializeAsString();
-      // we have the serialized string, so we can release model proto to lower the peak memory consumption
-      if (disable_cpu_fallback) model_proto.reset();
-      exe_network_ = OVCore::Get()->CompileModel(model,
-                                                 hw_target,
-                                                 device_config,
-                                                 subgraph_context_.subgraph_name);
-    } else {  // For all other types use ov::ov_core read_model() to generate OV IR
-              // followed by ov::ov_core compile_model()
-      std::string model = model_proto->SerializeAsString();
-      // Reset model proto only when cpu fallback is disabled or when the model has dynamic input shapes.
-      // This is to avoid memory peak usage when the model is large.
-      if (!subgraph_context_.has_dynamic_input_shape && disable_cpu_fallback) {
-        model_proto.reset();
+      model_stream.reset();  // Delete stream after it is no longer needed
+    } catch (const char* msg) {
+      ORT_THROW(msg);
+    }
+  } else {
+    std::string model = model_proto->SerializeAsString();
+    if (!subgraph_context_.has_dynamic_input_shape) {
+      model_proto.reset();
+    }
+    try {
+      // SetOVDeviceConfiguration(device_config);
+      if (!session_context_.has_external_weights &&
+          !subgraph_context_.has_dynamic_input_shape &&
+          !session_context_.so_context_enable &&
+          session_context_.reshape.empty() &&
+          !enable_causallm &&
+          auto_unified_compile) {
+        // Unified OV compile_model is efficient when ov model caching is enabled
+        // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
+        // Inputs with static dimensions
+        // Not enabled for models with external weights and when ep context is set.
+
+        exe_network_ = OVCore::Get()->CompileModel(model,
+                                                   hw_target,
+                                                   device_config,
+                                                   subgraph_context_.subgraph_name);
+      } else {  // For all other types use ov::ov_core read_model() to generate OV IR
+                // followed by ov::ov_core compile_model()
+        auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
+        exe_network_ = OVCore::Get()->CompileModel(
+            ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
+      }
+      LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
+    } catch (const OnnxRuntimeException& ex) {
+      std::string exception_str = ex.what();
+      bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos &&
+                                       !session_context_.so_disable_cpu_ep_fallback &&
+                                       !subgraph_context_.is_ep_ctx_graph;
+#if defined(OPENVINO_DISABLE_NPU_FALLBACK)
+      eligible_for_cpu_fallback = false;
+#endif
+      if (eligible_for_cpu_fallback) {
+        LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU."
+                              << " Falling back to OV CPU for execution";
+        session_context_.device_type = "CPU";
+        session_context_.precision = "FP32";
+        device_config.clear();
+        SetOVDeviceConfiguration(device_config);
+        try {
+          // Recreate the model with CPU device type
+          auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
+          exe_network_ = OVCore::Get()->CompileModel(
+              ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
+        } catch (std::string const& msg) {
+          ORT_THROW(msg);
+        }
+      } else {
+        ORT_THROW(ex.what());
+      }
       }
-      auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
-      exe_network_ = OVCore::Get()->CompileModel(
-          ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
     }
-    LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
-  } catch (const char* msg) {
-    ORT_THROW(msg);
   }
-
   int num_infer_req = (session_context_.num_of_threads > 0) ? session_context_.num_of_threads : 1;
   std::function<void(OVInferRequestPtr)> initializer = [](OVInferRequestPtr) {};
   auto metadata = shared_context_.shared_weights.metadata;
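The reworked catch block above replaces the old precomputed `disable_cpu_fallback` flag with an explicit eligibility test at the failure site: compilation is retried on CPU with FP32 precision only when the target is NPU, the user has not set the session option disabling CPU EP fallback, and the subgraph is not a precompiled EPContext graph; building with OPENVINO_DISABLE_NPU_FALLBACK vetoes the retry entirely. A standalone sketch of that test; the helper name and parameters are illustrative, not part of the codebase:

#include <string>

// Illustrative helper mirroring the eligible_for_cpu_fallback computation above.
bool EligibleForCpuFallback(const std::string& device_type,
                            bool so_disable_cpu_ep_fallback,
                            bool is_ep_ctx_graph) {
#if defined(OPENVINO_DISABLE_NPU_FALLBACK)
  return false;  // build-time veto, as in the diff
#else
  return device_type.find("NPU") != std::string::npos &&
         !so_disable_cpu_ep_fallback &&
         !is_ep_ctx_graph;
#endif
}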
@@ -385,6 +383,32 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
     device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
 }
 
+void BasicBackend::SetOVDeviceConfiguration(ov::AnyMap& device_config) {
+  PopulateConfigValue(device_config);
+
+  // Enable caching
+  EnableCaching();
+
+  // Setting OpenCL queue throttling for GPU
+  EnableGPUThrottling(device_config);
+
+  // Enable streams; default=1 unless overridden by user configuration
+  EnableStreams();
+
+  // Set the inference_num_threads property of the CPU
+  SetNumThreads(device_config);
+
+  auto npuw_status =
+      std::any_of(device_config.begin(), device_config.end(), [&](const std::pair<std::string, ov::Any>& pair) {
+        return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is<std::string>()) &&
+               (pair.second.as<std::string>() == "YES");
+      });
+
+  if (npuw_status) {
+    LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation";
+  }
+}
+
 void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
                                                       const ov::PartialShape& partial_shape) const {
   // Check if the number of dimensions matches
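The new `SetOVDeviceConfiguration` helper consolidates the five configuration calls that previously ran inline in the constructor, which is what lets the NPU fallback path above rebuild the config after switching `device_type` to CPU. Its final step scans the assembled config for an `NPU_USE_NPUW` entry set to `YES`. A compilable sketch of that scan, substituting a plain string map for `ov::AnyMap` so it builds without OpenVINO headers (illustrative only):

#include <algorithm>
#include <iostream>
#include <map>
#include <string>

// Stand-in for the std::any_of scan in SetOVDeviceConfiguration.
bool NpuwRequested(const std::map<std::string, std::string>& device_config) {
  return std::any_of(device_config.begin(), device_config.end(),
                     [](const std::pair<const std::string, std::string>& entry) {
                       return entry.first.find("NPU_USE_NPUW") != std::string::npos &&
                              entry.second == "YES";
                     });
}

int main() {
  std::map<std::string, std::string> config{{"NPU_USE_NPUW", "YES"}};
  if (NpuwRequested(config))
    std::cout << "NPUW Enabled during compilation\n";  // mirrors the INFO log above
}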