Skip to content

Commit

Permalink
feat: session builder optimization options
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Aug 31, 2024
1 parent e31720d commit bfa791d
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 286 deletions.
5 changes: 5 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ pub enum Error {
/// Error occurred when creating ONNX session options.
#[error("Failed to create ONNX Runtime session options: {0}")]
CreateSessionOptions(ErrorInternal),
/// Failed to enable `onnxruntime-extensions` for session.
#[error("Failed to enable `onnxruntime-extensions`: {0}")]
EnableExtensions(ErrorInternal),
#[error("Failed to add configuration entry to session builder: {0}")]
AddSessionConfigEntry(ErrorInternal),
/// Error occurred when creating an allocator from a [`crate::MemoryInfo`] struct while building a session.
#[error("Failed to create allocator from memory info: {0}")]
CreateAllocator(ErrorInternal),
Expand Down
203 changes: 203 additions & 0 deletions src/session/builder/impl_commit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
#[cfg(feature = "fetch-models")]
use std::fmt::Write;
use std::{any::Any, marker::PhantomData, path::Path, ptr::NonNull, sync::Arc};

use super::SessionBuilder;
#[cfg(feature = "fetch-models")]
use crate::error::FetchModelError;
use crate::{
environment::get_environment,
error::{Error, Result},
execution_providers::apply_execution_providers,
memory::Allocator,
ortsys,
session::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner}
};

impl SessionBuilder {
/// Downloads a pre-trained ONNX model from the given URL and builds the session.
#[cfg(feature = "fetch-models")]
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> {
let mut download_dir = ort_sys::internal::dirs::cache_dir()
.expect("could not determine cache directory")
.join("models");
if std::fs::create_dir_all(&download_dir).is_err() {
download_dir = std::env::current_dir().expect("Failed to obtain current working directory");
}

let url = model_url.as_ref();
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
let _ = write!(&mut s, "{:02x}", b);
s
});
let model_filepath = download_dir.join(model_filename);
let downloaded_path = if model_filepath.exists() {
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
model_filepath
} else {
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");

let resp = ureq::get(url).call().map_err(Box::new).map_err(FetchModelError::FetchError)?;

let len = resp
.header("Content-Length")
.and_then(|s| s.parse::<usize>().ok())
.expect("Missing Content-Length header");
tracing::info!(len, "Downloading {} bytes", len);

let mut reader = resp.into_reader();

let f = std::fs::File::create(&model_filepath).expect("Failed to create model file");
let mut writer = std::io::BufWriter::new(f);

let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(FetchModelError::IoError)?;
if bytes_io_count == len as u64 {
model_filepath
} else {
return Err(FetchModelError::CopyError {
expected: len as u64,
io: bytes_io_count
}
.into());
}
};

self.commit_from_file(downloaded_path)
}

/// Loads an ONNX model from a file and builds the session.
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session>
where
P: AsRef<Path>
{
let model_filepath = model_filepath_ref.as_ref();
if !model_filepath.exists() {
return Err(Error::FileDoesNotExist {
filename: model_filepath.to_path_buf()
});
}

let model_path = crate::util::path_to_os_char(model_filepath);

let env = get_environment()?;
apply_execution_providers(&self, env.execution_providers.iter().cloned())?;

if env.has_global_threadpool {
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
}

let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
ortsys![unsafe CreateSession(env.env_ptr.as_ptr(), model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];

let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) };

let allocator = match &self.memory_info {
Some(info) => {
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)];
unsafe { Allocator::from_raw_unchecked(allocator_ptr) }
}
None => Allocator::default()
};

// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, &allocator, i))
.collect::<Result<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, &allocator, i))
.collect::<Result<Vec<Output>>>()?;

let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>);
#[cfg(feature = "operator-libraries")]
let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box<dyn Any>));
let extras: Vec<Box<dyn Any>> = extras.collect();

Ok(Session {
inner: Arc::new(SharedSessionInner {
session_ptr,
allocator,
_extras: extras,
_environment: env
}),
inputs,
outputs
})
}

/// Load an ONNX graph from memory and commit the session
/// For `.ort` models, we enable `session.use_ort_model_bytes_directly`.
/// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array).
///
/// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that
/// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros).
pub fn commit_from_memory_directly(mut self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> {
// Enable zero-copy deserialization for models in `.ort` format.
self.add_config_entry("session.use_ort_model_bytes_directly", "1")?;
self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?;

let session = self.commit_from_memory(model_bytes)?;

Ok(InMemorySession { session, phantom: PhantomData })
}

/// Load an ONNX graph from memory and commit the session.
pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result<Session> {
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();

let env = get_environment()?;
apply_execution_providers(&self, env.execution_providers.iter().cloned())?;

if env.has_global_threadpool {
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
}

let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>();
let model_data_length = model_bytes.len();
ortsys![
unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession;
nonNull(session_ptr)
];

let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) };

let allocator = match &self.memory_info {
Some(info) => {
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)];
unsafe { Allocator::from_raw_unchecked(allocator_ptr) }
}
None => Allocator::default()
};

// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, &allocator, i))
.collect::<Result<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, &allocator, i))
.collect::<Result<Vec<Output>>>()?;

let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>);
#[cfg(feature = "operator-libraries")]
let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box<dyn Any>));
let extras: Vec<Box<dyn Any>> = extras.collect();

let session = Session {
inner: Arc::new(SharedSessionInner {
session_ptr,
allocator,
_extras: extras,
_environment: env
}),
inputs,
outputs
};
Ok(session)
}
}
100 changes: 100 additions & 0 deletions src/session/builder/impl_config_keys.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use super::SessionBuilder;
use crate::Result;

// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

impl SessionBuilder {
/// Enable/disable the usage of prepacking.
///
/// This option is **enabled** by default.
pub fn with_prepacking(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" })?;
Ok(self)
}

/// Use allocators from the registered environment.
///
/// This option is **disabled** by default.
pub fn with_env_allocators(mut self) -> Result<Self> {
self.add_config_entry("session.use_env_allocators", "1")?;
Ok(self)
}

/// Enable flush-to-zero and denormal-as-zero.
///
/// This option is **disabled** by default, as it may hurt model accuracy.
pub fn with_denormal_as_zero(mut self) -> Result<Self> {
self.add_config_entry("session.set_denormal_as_zero", "1")?;
Ok(self)
}

/// Enable/disable fusion for quantized models in QDQ (QuantizeLinear/DequantizeLinear) format.
///
/// This option is **enabled** by default for all EPs except DirectML.
pub fn with_quant_qdq(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" })?;
Ok(self)
}

/// Enable/disable the optimization step removing double QDQ nodes.
///
/// This option is **enabled** by default.
pub fn with_double_qdq_remover(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" })?;
Ok(self)
}

/// Enable the removal of Q/DQ node pairs once all QDQ handling has been completed.
///
/// This option is **disabled** by default.
pub fn with_qdq_cleanup(mut self) -> Result<Self> {
self.add_config_entry("session.enable_quant_qdq_cleanup", "1")?;
Ok(self)
}

/// Enable fast GELU approximation.
///
/// This option is **disabled** by default, as it may hurt accuracy.
pub fn with_approximate_gelu(mut self) -> Result<Self> {
self.add_config_entry("optimization.enable_gelu_approximation", "1")?;
Ok(self)
}

/// Enable/disable ahead-of-time function inlining.
///
/// This option is **enabled** by default.
pub fn with_aot_inlining(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" })?;
Ok(self)
}

/// Accepts a comma-separated list of optimizers to disable.
pub fn with_disabled_optimizers(mut self, optimizers: &str) -> Result<Self> {
self.add_config_entry("optimization.disable_specified_optimizers", optimizers)?;
Ok(self)
}

/// Enable using device allocator for allocating initialized tensor memory.
///
/// This option is **disabled** by default.
pub fn with_device_allocator_for_initializers(mut self) -> Result<Self> {
self.add_config_entry("session.use_device_allocator_for_initializers", "1")?;
Ok(self)
}

/// Enable/disable allowing the inter-op threads to spin for a short period before blocking.
///
/// This option is **enabled** by defualt.
pub fn with_inter_op_spinning(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" })?;
Ok(self)
}

/// Enable/disable allowing the intra-op threads to spin for a short period before blocking.
///
/// This option is **enabled** by defualt.
pub fn with_intra_op_spinning(mut self, enable: bool) -> Result<Self> {
self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" })?;
Ok(self)
}
}
Loading

0 comments on commit bfa791d

Please sign in to comment.