Skip to content

Commit

Permalink
feat: many a feature
Browse files Browse the repository at this point in the history
I did that thing again!

Features in this commit:
- `ThreadManager` allows you to define custom thread creation functions for environments & sessions.
- Sessions can now opt-out of using the environment's global thread pool.
- Implemented the safe `ShapeInferenceContext` wrapper for custom operators.
- Prepacked weights allow the CPU execution provider to share one allocation for identical weights between sessions.
- Customize workload type to prioritize efficiency; useful for background tasks.
- Configurable per-session log identifiers.
- Dynamic dimension overrides.

Breaking changes:
- `EnvironmentGlobalThreadPoolOptions` is now `GlobalThreadPoolOptions` and uses the builder pattern instead of exposed struct fields.
  • Loading branch information
decahedron1 committed Nov 19, 2024
1 parent 7819d56 commit 87577ef
Show file tree
Hide file tree
Showing 12 changed files with 536 additions and 54 deletions.
8 changes: 8 additions & 0 deletions examples/custom-ops/examples/custom-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ impl Operator for CustomOpOne {
fn outputs() -> Vec<OperatorOutput> {
vec![OperatorOutput::required(TensorElementType::Float32)]
}

/// Provides the shape inference callback for this operator.
///
/// The callback reads the input types from the context and assigns the type
/// of input 0 to output 0.
fn get_infer_shape_function() -> Option<Box<ort::operator::InferShapeFn>> {
    Some(Box::new(|ctx| {
        let input_types = ctx.inputs();
        // Output 0 takes exactly the type reported for input 0.
        ctx.set_output(0, &input_types[0])
    }))
}
}

impl Kernel for CustomOpOneKernel {
Expand Down
165 changes: 134 additions & 31 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
//! ```
use std::{
any::Any,
ffi::{self, CStr, CString},
os::raw::c_void,
ptr::{self, NonNull},
sync::{Arc, RwLock}
};
Expand Down Expand Up @@ -47,7 +49,8 @@ static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(No
pub struct Environment {
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
ptr: NonNull<ort_sys::OrtEnv>,
pub(crate) has_global_threadpool: bool
pub(crate) has_global_threadpool: bool,
_thread_manager: Option<Box<dyn Any>>
}

unsafe impl Send for Environment {}
Expand Down Expand Up @@ -83,20 +86,134 @@ pub fn get_environment() -> Result<Arc<Environment>> {
}
}

#[derive(Debug, Default, Clone)]
pub struct EnvironmentGlobalThreadPoolOptions {
pub inter_op_parallelism: Option<i32>,
pub intra_op_parallelism: Option<i32>,
pub spin_control: Option<bool>,
pub intra_op_thread_affinity: Option<String>
#[derive(Debug)]
pub struct GlobalThreadPoolOptions {
ptr: *mut ort_sys::OrtThreadingOptions,
thread_manager: Option<Box<dyn Any>>
}

impl Default for GlobalThreadPoolOptions {
    /// Creates a new threading options object via `CreateThreadingOptions`,
    /// with no custom thread manager attached.
    ///
    /// # Panics
    /// Panics if `CreateThreadingOptions` reports an error. `Default::default`
    /// cannot return a `Result`, and continuing would leave every subsequent
    /// setter operating on a null options pointer.
    fn default() -> Self {
        let mut ptr = ptr::null_mut();
        // Check the returned status instead of discarding it.
        ortsys![unsafe CreateThreadingOptions(&mut ptr).expect("failed to create threading options")];
        Self { ptr, thread_manager: None }
    }
}

impl GlobalThreadPoolOptions {
pub fn with_inter_threads(mut self, num_threads: usize) -> Result<Self> {
ortsys![unsafe SetGlobalInterOpNumThreads(self.ptr_mut(), num_threads as _)?];
Ok(self)
}

pub fn with_intra_threads(mut self, num_threads: usize) -> Result<Self> {
ortsys![unsafe SetGlobalIntraOpNumThreads(self.ptr_mut(), num_threads as _)?];
Ok(self)
}

pub fn with_spin_control(mut self, spin_control: bool) -> Result<Self> {
ortsys![unsafe SetGlobalSpinControl(self.ptr_mut(), if spin_control { 1 } else { 0 })?];
Ok(self)
}

pub fn with_intra_affinity(mut self, affinity: impl AsRef<str>) -> Result<Self> {
let affinity = CString::new(affinity.as_ref())?;
ortsys![unsafe SetGlobalIntraOpThreadAffinity(self.ptr_mut(), affinity.as_ptr())?];
Ok(self)
}

pub fn with_flush_to_zero(mut self) -> Result<Self> {
ortsys![unsafe SetGlobalDenormalAsZero(self.ptr_mut())?];
Ok(self)
}

pub fn with_thread_manager<T: ThreadManager + Any + 'static>(mut self, manager: T) -> Result<Self> {
let mut manager = Box::new(manager);
ortsys![unsafe SetGlobalCustomThreadCreationOptions(self.ptr_mut(), (&mut *manager as *mut T).cast())?];
ortsys![unsafe SetGlobalCustomCreateThreadFn(self.ptr_mut(), Some(thread_create::<T>))?];
ortsys![unsafe SetGlobalCustomJoinThreadFn(self.ptr_mut(), Some(thread_join::<T>))?];
self.thread_manager = Some(manager as Box<dyn Any>);
Ok(self)
}
}

impl AsPointer for GlobalThreadPoolOptions {
    type Sys = ort_sys::OrtThreadingOptions;

    /// Exposes the wrapped threading options pointer as a `*const`.
    fn ptr(&self) -> *const Self::Sys {
        self.ptr.cast_const()
    }
}

// Releases the wrapped threading options when the wrapper goes out of scope.
impl Drop for GlobalThreadPoolOptions {
// Hands the stored pointer back to `ReleaseThreadingOptions`.
fn drop(&mut self) {
ortsys![unsafe ReleaseThreadingOptions(self.ptr)];
}
}

/// A unit of work handed to a [`ThreadManager`]: a worker function together
/// with the opaque data pointer that must be passed to it.
///
/// Built in `thread_create` from the arguments of the thread creation callback
/// and executed by calling [`ThreadWorker::work`].
pub struct ThreadWorker {
// Opaque parameter forwarded unchanged to `worker`.
data: *mut c_void,
// Worker function to invoke with `data`.
worker: ort_sys::OrtThreadWorkerFn
}

// NOTE(review): the raw `data` pointer makes `ThreadWorker` `!Send` by default. This impl
// presumably exists so a worker can be moved onto the thread that runs it; it assumes the
// pointed-to data may be used from that thread. TODO confirm against the ONNX Runtime contract.
unsafe impl Send for ThreadWorker {}

impl ThreadWorker {
    /// Consumes the worker and invokes its worker function with the stored
    /// data pointer.
    pub fn work(self) {
        // NOTE(review): assumes the worker function is always present; `unwrap_unchecked`
        // is undefined behavior on `None`. TODO confirm the caller guarantees this.
        let entry_point = unsafe { self.worker.unwrap_unchecked() };
        unsafe { entry_point(self.data) }
    }
}

/// Customizes how threads are created and joined.
///
/// `thread_create` calls [`ThreadManager::create`] with the manager registered
/// through `with_thread_manager`, and `thread_join` calls
/// [`ThreadManager::join`] with the value `create` returned. Errors returned
/// from either method are logged rather than propagated.
pub trait ThreadManager {
/// The value representing a created thread. It is boxed and passed through
/// the C API as an opaque handle between `create` and `join`.
type Thread;

/// Creates a thread for `worker`.
///
/// Presumably the new thread is expected to call [`ThreadWorker::work`] -
/// verify against the ONNX Runtime custom thread documentation.
fn create(&mut self, worker: ThreadWorker) -> crate::Result<Self::Thread>;

/// Joins a thread previously returned by [`ThreadManager::create`].
fn join(thread: Self::Thread) -> crate::Result<()>;
}

pub(crate) unsafe extern "C" fn thread_create<T: ThreadManager + Any>(
ort_custom_thread_creation_options: *mut c_void,
ort_thread_worker_fn: ort_sys::OrtThreadWorkerFn,
ort_worker_fn_param: *mut c_void
) -> ort_sys::OrtCustomThreadHandle {
let thread_worker = ThreadWorker {
data: ort_worker_fn_param,
worker: ort_thread_worker_fn
};

let res = std::panic::catch_unwind(|| {
let manager = unsafe { &mut *ort_custom_thread_creation_options.cast::<T>() };
<T as ThreadManager>::create(manager, thread_worker)
});
match res {
Ok(Ok(thread)) => (Box::leak(Box::new(thread)) as *mut <T as ThreadManager>::Thread)
.cast_const()
.cast::<ort_sys::OrtCustomHandleType>(),
Ok(Err(e)) => {
tracing::error!("Failed to create thread using manager: {e}");
ptr::null()
}
Err(e) => {
tracing::error!("Thread manager panicked: {e:?}");
ptr::null()
}
}
}

pub(crate) unsafe extern "C" fn thread_join<T: ThreadManager + Any>(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) {
let handle = Box::from_raw(ort_custom_thread_handle.cast_mut().cast::<<T as ThreadManager>::Thread>());
if let Err(e) = <T as ThreadManager>::join(*handle) {
tracing::error!("Failed to join thread using manager: {e}");
}
}

/// Struct used to build an [`Environment`]; see [`crate::init`].
pub struct EnvironmentBuilder {
name: String,
telemetry: bool,
execution_providers: Vec<ExecutionProviderDispatch>,
global_thread_pool_options: Option<EnvironmentGlobalThreadPoolOptions>
global_thread_pool_options: Option<GlobalThreadPoolOptions>
}

impl EnvironmentBuilder {
Expand Down Expand Up @@ -153,48 +270,33 @@ impl EnvironmentBuilder {

/// Enables the global thread pool for this environment.
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self {
pub fn with_global_thread_pool(mut self, options: GlobalThreadPoolOptions) -> Self {
self.global_thread_pool_options = Some(options);
self
}

/// Commit the environment configuration and set the global environment.
pub fn commit(self) -> Result<Arc<Environment>> {
let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options {
let (env_ptr, thread_manager, has_global_threadpool) = if let Some(mut thread_pool_options) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());

let mut thread_options: *mut ort_sys::OrtThreadingOptions = std::ptr::null_mut();
ortsys![unsafe CreateThreadingOptions(&mut thread_options)?; nonNull(thread_options)];
if let Some(inter_op_parallelism) = global_thread_pool.inter_op_parallelism {
ortsys![unsafe SetGlobalInterOpNumThreads(thread_options, inter_op_parallelism)?];
}
if let Some(intra_op_parallelism) = global_thread_pool.intra_op_parallelism {
ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism)?];
}
if let Some(spin_control) = global_thread_pool.spin_control {
ortsys![unsafe SetGlobalSpinControl(thread_options, i32::from(spin_control))?];
}
if let Some(intra_op_thread_affinity) = global_thread_pool.intra_op_thread_affinity {
let cstr = CString::new(intra_op_thread_affinity).unwrap_or_else(|_| unreachable!());
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr())?];
}

ortsys![
unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
thread_options,
thread_pool_options.ptr(),
&mut env_ptr
)?;
nonNull(env_ptr)
];
ortsys![unsafe ReleaseThreadingOptions(thread_options)];
(env_ptr, true)

let thread_manager = thread_pool_options.thread_manager.take();
(env_ptr, thread_manager, true)
} else {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
Expand All @@ -211,7 +313,7 @@ impl EnvironmentBuilder {
)?;
nonNull(env_ptr)
];
(env_ptr, false)
(env_ptr, None, false)
};
debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");

Expand All @@ -230,7 +332,8 @@ impl EnvironmentBuilder {
execution_providers: self.execution_providers,
// we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call
ptr: unsafe { NonNull::new_unchecked(env_ptr) },
has_global_threadpool
has_global_threadpool,
_thread_manager: thread_manager
});
env_lock.replace(Arc::clone(&env));

Expand Down
9 changes: 6 additions & 3 deletions src/operator/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
};

use super::{
DummyOperator, Operator,
DummyOperator, Operator, ShapeInferenceContext,
io::InputOutputCharacteristic,
kernel::{Kernel, KernelAttributes, KernelContext}
};
Expand Down Expand Up @@ -203,8 +203,11 @@ impl<O: Operator> BoundOperator<O> {
}

extern_system_fn! {
pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status()
pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, ctx: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
let mut ctx = ShapeInferenceContext {
ptr: ctx
};
O::get_infer_shape_function().expect("missing infer shape function")(&mut ctx).into_status()
}
}
}
Expand Down
49 changes: 46 additions & 3 deletions src/operator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ use self::{
io::{OperatorInput, OperatorOutput},
kernel::{DummyKernel, Kernel, KernelAttributes}
};
use crate::{AsPointer, error::Result, ortsys};

pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>;
use crate::{
AsPointer, Error,
error::Result,
ortsys,
value::{ValueType, r#type::extract_data_type_from_tensor_info}
};

/// A custom operator descriptor, which describes the expected inputs & outputs of a graph operator.
///
Expand Down Expand Up @@ -84,6 +87,46 @@ impl Operator for DummyOperator {
}
}

/// Signature of a custom operator's shape inference callback: it receives a
/// mutable [`ShapeInferenceContext`] and returns `Ok(())` or an error.
pub type InferShapeFn = dyn FnMut(&mut ShapeInferenceContext) -> crate::Result<()> + 'static;

/// Wrapper around a raw `OrtShapeInferContext` pointer, passed to an
/// [`InferShapeFn`] so it can read input types and set output types.
pub struct ShapeInferenceContext {
// Raw context pointer this wrapper forwards to the C API.
ptr: *mut ort_sys::OrtShapeInferContext
}

impl ShapeInferenceContext {
    /// Returns the type of every input, in input order.
    ///
    /// # Panics
    /// Panics if the input count or any input's type/shape info cannot be
    /// retrieved from the context.
    pub fn inputs(&self) -> Vec<ValueType> {
        let mut count = 0;
        ortsys![unsafe ShapeInferContext_GetInputCount(self.ptr(), &mut count).expect("failed to get input count")];

        let mut tys = Vec::with_capacity(count);
        for i in 0..count {
            let mut ty_info = ptr::null_mut();
            ortsys![unsafe ShapeInferContext_GetInputTypeShape(self.ptr(), i, &mut ty_info).expect("failed to get info type")];
            tys.push(unsafe { extract_data_type_from_tensor_info(ty_info) });
        }
        tys
    }

    /// Sets the type of the output at index `idx` to `ty`.
    ///
    /// # Errors
    /// Returns an error if `ty` cannot be converted to tensor type info (only
    /// tensors are supported) or if the underlying API call fails.
    pub fn set_output(&mut self, idx: usize, ty: &ValueType) -> Result<()> {
        match ty.to_tensor_type_info() {
            Some(ty_ptr) => {
                // Run the fallible call in a closure so that `?` cannot return past the
                // release below; otherwise `ty_ptr` would leak whenever setting the output fails.
                let outcome = (|| -> Result<()> {
                    ortsys![unsafe ShapeInferContext_SetOutputTypeShape(self.ptr(), idx, ty_ptr)?];
                    Ok(())
                })();
                ortsys![unsafe ReleaseTensorTypeAndShapeInfo(ty_ptr)];
                outcome
            }
            None => Err(Error::new("only tensors are supported"))
        }
    }
}

impl AsPointer for ShapeInferenceContext {
    type Sys = ort_sys::OrtShapeInferContext;

    /// Exposes the wrapped shape inference context pointer as a `*const`.
    fn ptr(&self) -> *const Self::Sys {
        self.ptr.cast_const()
    }
}

pub struct OperatorDomain {
ptr: NonNull<ort_sys::OrtCustomOpDomain>,
_name: CString,
Expand Down
Loading

0 comments on commit 87577ef

Please sign in to comment.