-
-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: session builder optimization options
- Loading branch information
1 parent
e31720d
commit bfa791d
Showing
5 changed files
with
419 additions
and
286 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
#[cfg(feature = "fetch-models")] | ||
use std::fmt::Write; | ||
use std::{any::Any, marker::PhantomData, path::Path, ptr::NonNull, sync::Arc}; | ||
|
||
use super::SessionBuilder; | ||
#[cfg(feature = "fetch-models")] | ||
use crate::error::FetchModelError; | ||
use crate::{ | ||
environment::get_environment, | ||
error::{Error, Result}, | ||
execution_providers::apply_execution_providers, | ||
memory::Allocator, | ||
ortsys, | ||
session::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner} | ||
}; | ||
|
||
impl SessionBuilder { | ||
/// Downloads a pre-trained ONNX model from the given URL and builds the session. | ||
#[cfg(feature = "fetch-models")] | ||
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))] | ||
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> { | ||
let mut download_dir = ort_sys::internal::dirs::cache_dir() | ||
.expect("could not determine cache directory") | ||
.join("models"); | ||
if std::fs::create_dir_all(&download_dir).is_err() { | ||
download_dir = std::env::current_dir().expect("Failed to obtain current working directory"); | ||
} | ||
|
||
let url = model_url.as_ref(); | ||
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| { | ||
let _ = write!(&mut s, "{:02x}", b); | ||
s | ||
}); | ||
let model_filepath = download_dir.join(model_filename); | ||
let downloaded_path = if model_filepath.exists() { | ||
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); | ||
model_filepath | ||
} else { | ||
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model"); | ||
|
||
let resp = ureq::get(url).call().map_err(Box::new).map_err(FetchModelError::FetchError)?; | ||
|
||
let len = resp | ||
.header("Content-Length") | ||
.and_then(|s| s.parse::<usize>().ok()) | ||
.expect("Missing Content-Length header"); | ||
tracing::info!(len, "Downloading {} bytes", len); | ||
|
||
let mut reader = resp.into_reader(); | ||
|
||
let f = std::fs::File::create(&model_filepath).expect("Failed to create model file"); | ||
let mut writer = std::io::BufWriter::new(f); | ||
|
||
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(FetchModelError::IoError)?; | ||
if bytes_io_count == len as u64 { | ||
model_filepath | ||
} else { | ||
return Err(FetchModelError::CopyError { | ||
expected: len as u64, | ||
io: bytes_io_count | ||
} | ||
.into()); | ||
} | ||
}; | ||
|
||
self.commit_from_file(downloaded_path) | ||
} | ||
|
||
/// Loads an ONNX model from a file and builds the session. | ||
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session> | ||
where | ||
P: AsRef<Path> | ||
{ | ||
let model_filepath = model_filepath_ref.as_ref(); | ||
if !model_filepath.exists() { | ||
return Err(Error::FileDoesNotExist { | ||
filename: model_filepath.to_path_buf() | ||
}); | ||
} | ||
|
||
let model_path = crate::util::path_to_os_char(model_filepath); | ||
|
||
let env = get_environment()?; | ||
apply_execution_providers(&self, env.execution_providers.iter().cloned())?; | ||
|
||
if env.has_global_threadpool { | ||
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; | ||
} | ||
|
||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); | ||
ortsys![unsafe CreateSession(env.env_ptr.as_ptr(), model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)]; | ||
|
||
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; | ||
|
||
let allocator = match &self.memory_info { | ||
Some(info) => { | ||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); | ||
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; | ||
unsafe { Allocator::from_raw_unchecked(allocator_ptr) } | ||
} | ||
None => Allocator::default() | ||
}; | ||
|
||
// Extract input and output properties | ||
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; | ||
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; | ||
let inputs = (0..num_input_nodes) | ||
.map(|i| dangerous::extract_input(session_ptr, &allocator, i)) | ||
.collect::<Result<Vec<Input>>>()?; | ||
let outputs = (0..num_output_nodes) | ||
.map(|i| dangerous::extract_output(session_ptr, &allocator, i)) | ||
.collect::<Result<Vec<Output>>>()?; | ||
|
||
let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>); | ||
#[cfg(feature = "operator-libraries")] | ||
let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box<dyn Any>)); | ||
let extras: Vec<Box<dyn Any>> = extras.collect(); | ||
|
||
Ok(Session { | ||
inner: Arc::new(SharedSessionInner { | ||
session_ptr, | ||
allocator, | ||
_extras: extras, | ||
_environment: env | ||
}), | ||
inputs, | ||
outputs | ||
}) | ||
} | ||
|
||
/// Load an ONNX graph from memory and commit the session | ||
/// For `.ort` models, we enable `session.use_ort_model_bytes_directly`. | ||
/// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array). | ||
/// | ||
/// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that | ||
/// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros). | ||
pub fn commit_from_memory_directly(mut self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> { | ||
// Enable zero-copy deserialization for models in `.ort` format. | ||
self.add_config_entry("session.use_ort_model_bytes_directly", "1")?; | ||
self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?; | ||
|
||
let session = self.commit_from_memory(model_bytes)?; | ||
|
||
Ok(InMemorySession { session, phantom: PhantomData }) | ||
} | ||
|
||
/// Load an ONNX graph from memory and commit the session. | ||
pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result<Session> { | ||
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); | ||
|
||
let env = get_environment()?; | ||
apply_execution_providers(&self, env.execution_providers.iter().cloned())?; | ||
|
||
if env.has_global_threadpool { | ||
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; | ||
} | ||
|
||
let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>(); | ||
let model_data_length = model_bytes.len(); | ||
ortsys![ | ||
unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; | ||
nonNull(session_ptr) | ||
]; | ||
|
||
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) }; | ||
|
||
let allocator = match &self.memory_info { | ||
Some(info) => { | ||
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); | ||
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; | ||
unsafe { Allocator::from_raw_unchecked(allocator_ptr) } | ||
} | ||
None => Allocator::default() | ||
}; | ||
|
||
// Extract input and output properties | ||
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; | ||
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; | ||
let inputs = (0..num_input_nodes) | ||
.map(|i| dangerous::extract_input(session_ptr, &allocator, i)) | ||
.collect::<Result<Vec<Input>>>()?; | ||
let outputs = (0..num_output_nodes) | ||
.map(|i| dangerous::extract_output(session_ptr, &allocator, i)) | ||
.collect::<Result<Vec<Output>>>()?; | ||
|
||
let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>); | ||
#[cfg(feature = "operator-libraries")] | ||
let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box<dyn Any>)); | ||
let extras: Vec<Box<dyn Any>> = extras.collect(); | ||
|
||
let session = Session { | ||
inner: Arc::new(SharedSessionInner { | ||
session_ptr, | ||
allocator, | ||
_extras: extras, | ||
_environment: env | ||
}), | ||
inputs, | ||
outputs | ||
}; | ||
Ok(session) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
use super::SessionBuilder; | ||
use crate::Result; | ||
|
||
// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h | ||
|
||
impl SessionBuilder { | ||
/// Enable/disable the usage of prepacking. | ||
/// | ||
/// This option is **enabled** by default. | ||
pub fn with_prepacking(mut self, enable: bool) -> Result<Self> { | ||
self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" })?; | ||
Ok(self) | ||
} | ||
|
||
/// Use allocators from the registered environment. | ||
/// | ||
/// This option is **disabled** by default. | ||
pub fn with_env_allocators(mut self) -> Result<Self> { | ||
self.add_config_entry("session.use_env_allocators", "1")?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable flush-to-zero and denormal-as-zero. | ||
/// | ||
/// This option is **disabled** by default, as it may hurt model accuracy. | ||
pub fn with_denormal_as_zero(mut self) -> Result<Self> { | ||
self.add_config_entry("session.set_denormal_as_zero", "1")?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable/disable fusion for quantized models in QDQ (QuantizeLinear/DequantizeLinear) format. | ||
/// | ||
/// This option is **enabled** by default for all EPs except DirectML. | ||
pub fn with_quant_qdq(mut self, enable: bool) -> Result<Self> { | ||
self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" })?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable/disable the optimization step removing double QDQ nodes. | ||
/// | ||
/// This option is **enabled** by default. | ||
pub fn with_double_qdq_remover(mut self, enable: bool) -> Result<Self> { | ||
self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" })?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable the removal of Q/DQ node pairs once all QDQ handling has been completed. | ||
/// | ||
/// This option is **disabled** by default. | ||
pub fn with_qdq_cleanup(mut self) -> Result<Self> { | ||
self.add_config_entry("session.enable_quant_qdq_cleanup", "1")?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable fast GELU approximation. | ||
/// | ||
/// This option is **disabled** by default, as it may hurt accuracy. | ||
pub fn with_approximate_gelu(mut self) -> Result<Self> { | ||
self.add_config_entry("optimization.enable_gelu_approximation", "1")?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable/disable ahead-of-time function inlining. | ||
/// | ||
/// This option is **enabled** by default. | ||
pub fn with_aot_inlining(mut self, enable: bool) -> Result<Self> { | ||
self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" })?; | ||
Ok(self) | ||
} | ||
|
||
/// Accepts a comma-separated list of optimizers to disable. | ||
pub fn with_disabled_optimizers(mut self, optimizers: &str) -> Result<Self> { | ||
self.add_config_entry("optimization.disable_specified_optimizers", optimizers)?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable using device allocator for allocating initialized tensor memory. | ||
/// | ||
/// This option is **disabled** by default. | ||
pub fn with_device_allocator_for_initializers(mut self) -> Result<Self> { | ||
self.add_config_entry("session.use_device_allocator_for_initializers", "1")?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable/disable allowing the inter-op threads to spin for a short period before blocking. | ||
/// | ||
/// This option is **enabled** by defualt. | ||
pub fn with_inter_op_spinning(mut self, enable: bool) -> Result<Self> { | ||
self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" })?; | ||
Ok(self) | ||
} | ||
|
||
/// Enable/disable allowing the intra-op threads to spin for a short period before blocking. | ||
/// | ||
/// This option is **enabled** by defualt. | ||
pub fn with_intra_op_spinning(mut self, enable: bool) -> Result<Self> { | ||
self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" })?; | ||
Ok(self) | ||
} | ||
} |
Oops, something went wrong.