From 8a16adbb6b2a8ff721152e6a27ac50b920214737 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 4 Nov 2024 20:10:58 -0600 Subject: [PATCH] refactor!: reduce allocations in `run`; make `SessionOutputs` not a map --- src/io_binding.rs | 10 +- src/session/async.rs | 18 +- src/session/mod.rs | 17 +- src/session/output.rs | 399 ++++++++++++++++++++++++++++++++++++---- src/training/trainer.rs | 4 +- 5 files changed, 383 insertions(+), 65 deletions(-) diff --git a/src/io_binding.rs b/src/io_binding.rs index 063d570..4011246 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -231,11 +231,17 @@ impl IoBinding { Some(Arc::clone(&self.session)) ) } - }); + }) + .collect::>(); // output values will be freed when the `Value`s in `SessionOutputs` drop - Ok(SessionOutputs::new_backed(self.output_names.iter().map(String::as_str), output_values, &self.session.allocator, output_values_ptr.cast())) + Ok(SessionOutputs::new_backed( + self.output_names.iter().map(String::as_str).collect(), + output_values, + &self.session.allocator, + output_values_ptr.cast() + )) } else { Ok(SessionOutputs::new_empty()) } diff --git a/src/session/async.rs b/src/session/async.rs index 770fc11..8e25449 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -12,7 +12,7 @@ use std::{ use ort_sys::{OrtStatus, c_void}; use crate::{ - error::{Result, assert_non_null_pointer}, + error::Result, session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner}, value::Value }; @@ -138,17 +138,9 @@ crate::extern_system_fn! { let ctx = unsafe { Box::from_raw(user_data.cast::>()) }; // Reconvert name ptrs to CString so drop impl is called and memory is freed - drop( - ctx.input_name_ptrs - .into_iter() - .chain(ctx.output_name_ptrs) - .map(|p| { - assert_non_null_pointer(p, "c_char for CString")?; - unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } - }) - .collect::>>() - .expect("Input name should not be null") - ); + for p in ctx.input_name_ptrs { + drop(unsafe { CString::from_raw(p.cast_mut().cast()) }); + } if let Err(e) = crate::error::status_to_result(status) { ctx.inner.emplace_value(Err(e)); @@ -164,7 +156,7 @@ crate::extern_system_fn! { }) .collect(); - ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names.into_iter(), outputs))); + ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names, outputs))); ctx.inner.wake(); } } diff --git a/src/session/mod.rs b/src/session/mod.rs index 9c5fee0..15a5b1e 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -324,18 +324,11 @@ impl Session { .collect(); // Reconvert name ptrs to CString so drop impl is called and memory is freed - drop( - input_names_ptr - .into_iter() - .chain(output_names_ptr.into_iter()) - .map(|p| { - assert_non_null_pointer(p, "c_char for CString")?; - unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } - }) - .collect::>>()? - ); - - Ok(SessionOutputs::new(output_names.into_iter(), outputs)) + for p in input_names_ptr.into_iter().chain(output_names_ptr.into_iter()) { + drop(unsafe { CString::from_raw(p.cast_mut().cast()) }); + } + + Ok(SessionOutputs::new(output_names, outputs)) } /// Asynchronously run input data through the ONNX graph, performing inference. diff --git a/src/session/output.rs b/src/session/output.rs index ffb8577..f9df5f6 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -1,10 +1,14 @@ use std::{ - collections::BTreeMap, ffi::c_void, - ops::{Deref, DerefMut, Index} + iter::FusedIterator, + marker::PhantomData, + mem::ManuallyDrop, + ops::{Index, IndexMut}, + ptr, + sync::Arc }; -use crate::{memory::Allocator, value::DynValue}; +use crate::{ValueRef, ValueRefMut, memory::Allocator, value::DynValue}; /// The outputs returned by a [`crate::Session`] inference call. /// @@ -26,44 +30,172 @@ use crate::{memory::Allocator, value::DynValue}; /// ``` #[derive(Debug)] pub struct SessionOutputs<'r, 's> { - map: BTreeMap<&'r str, DynValue>, - idxs: Vec<&'r str>, + keys: Vec<&'r str>, + values: Vec, + effective_len: usize, backing_ptr: Option<(&'s Allocator, *mut c_void)> } unsafe impl Send for SessionOutputs<'_, '_> {} impl<'r, 's> SessionOutputs<'r, 's> { - pub(crate) fn new(output_names: impl Iterator + Clone, output_values: impl IntoIterator) -> Self { - let map = output_names.clone().zip(output_values).collect(); + pub(crate) fn new(output_names: Vec<&'r str>, output_values: Vec) -> Self { + debug_assert_eq!(output_names.len(), output_values.len()); Self { - map, - idxs: output_names.collect(), + effective_len: output_names.len(), + keys: output_names, + values: output_values, backing_ptr: None } } - pub(crate) fn new_backed( - output_names: impl Iterator + Clone, - output_values: impl IntoIterator, - allocator: &'s Allocator, - backing_ptr: *mut c_void - ) -> Self { - let map = output_names.clone().zip(output_values).collect(); + pub(crate) fn new_backed(output_names: Vec<&'r str>, output_values: Vec, allocator: &'s Allocator, backing_ptr: *mut c_void) -> Self { + debug_assert_eq!(output_names.len(), output_values.len()); Self { - map, - idxs: output_names.collect(), + effective_len: output_names.len(), + keys: output_names, + values: output_values, backing_ptr: Some((allocator, backing_ptr)) } } pub(crate) fn new_empty() -> Self { Self { - map: BTreeMap::new(), - idxs: Vec::new(), + effective_len: 0, + keys: Vec::new(), + values: Vec::new(), backing_ptr: None } } + + pub fn contains_key(&self, key: impl AsRef) -> bool { + let key = key.as_ref(); + assert!(!key.is_empty(), "output name cannot be empty"); + for k in &self.keys { + if &key == k { + return true; + } + } + false + } + + pub fn get(&self, key: impl AsRef) -> Option<&DynValue> { + let key = key.as_ref(); + assert!(!key.is_empty(), "output name cannot be empty"); + for (i, k) in self.keys.iter().enumerate() { + if &key == k { + return Some(&self.values[i]); + } + } + None + } + + pub fn get_mut(&mut self, key: impl AsRef) -> Option<&mut DynValue> { + let key = key.as_ref(); + assert!(!key.is_empty(), "output name cannot be empty"); + for (i, k) in self.keys.iter().enumerate() { + if &key == k { + return Some(&mut self.values[i]); + } + } + None + } + + pub fn remove(&mut self, key: impl AsRef) -> Option { + let key = key.as_ref(); + assert!(!key.is_empty(), "output name cannot be empty"); + for (i, k) in self.keys.iter_mut().enumerate() { + if &key == k { + *k = ""; + self.effective_len -= 1; + return Some(DynValue { + inner: Arc::clone(&self.values[i].inner), + _markers: PhantomData + }); + } + } + None + } + + #[inline(always)] + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.effective_len + } + + pub fn keys(&self) -> Keys<'_, 'r> { + Keys { + iter: self.keys.iter(), + effective_len: self.effective_len + } + } + + pub fn values(&self) -> Values<'_, 'r> { + Values { + key_iter: self.keys.iter(), + value_iter: self.values.iter(), + effective_len: self.effective_len + } + } + + pub fn values_mut(&mut self) -> ValuesMut<'_, 'r> { + ValuesMut { + key_iter: self.keys.iter(), + value_iter: self.values.iter_mut(), + effective_len: self.effective_len + } + } + + pub fn iter(&self) -> Iter<'_, 'r> { + Iter { + key_iter: self.keys.iter(), + value_iter: self.values.iter(), + effective_len: self.effective_len + } + } + + pub fn iter_mut(&mut self) -> IterMut<'_, 'r> { + IterMut { + key_iter: self.keys.iter(), + value_iter: self.values.iter_mut(), + effective_len: self.effective_len + } + } +} + +impl<'x, 'r> IntoIterator for &'x SessionOutputs<'r, '_> { + type IntoIter = Iter<'x, 'r>; + type Item = (&'r str, ValueRef<'x>); + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'x, 'r> IntoIterator for &'x mut SessionOutputs<'r, '_> { + type IntoIter = IterMut<'x, 'r>; + type Item = (&'r str, ValueRefMut<'x>); + + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } +} + +impl<'r, 's> IntoIterator for SessionOutputs<'r, 's> { + type IntoIter = IntoIter<'r, 's>; + type Item = (&'r str, DynValue); + + fn into_iter(self) -> Self::IntoIter { + let this = ManuallyDrop::new(self); + let keys = unsafe { ptr::read(&this.keys) }.into_iter(); + let values = unsafe { ptr::read(&this.values) }.into_iter(); + IntoIter { + keys, + values, + effective_len: this.effective_len, + backing_ptr: this.backing_ptr + } + } } impl Drop for SessionOutputs<'_, '_> { @@ -74,37 +206,232 @@ impl Drop for SessionOutputs<'_, '_> { } } -impl<'r> Deref for SessionOutputs<'r, '_> { - type Target = BTreeMap<&'r str, DynValue>; - - fn deref(&self) -> &Self::Target { - &self.map +impl Index<&str> for SessionOutputs<'_, '_> { + type Output = DynValue; + fn index(&self, key: &str) -> &Self::Output { + self.get(key).unwrap_or_else(|| panic!("no output named `{key}`")) } } -impl DerefMut for SessionOutputs<'_, '_> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.map +impl IndexMut<&str> for SessionOutputs<'_, '_> { + fn index_mut(&mut self, key: &str) -> &mut Self::Output { + self.get_mut(key).unwrap_or_else(|| panic!("no output named `{key}`")) } } -impl Index<&str> for SessionOutputs<'_, '_> { +impl Index for SessionOutputs<'_, '_> { type Output = DynValue; - fn index(&self, index: &str) -> &Self::Output { - self.map.get(index).expect("no entry found for key") + fn index(&self, key: String) -> &Self::Output { + self.get(&key).unwrap_or_else(|| panic!("no output named `{key}`")) } } -impl Index for SessionOutputs<'_, '_> { - type Output = DynValue; - fn index(&self, index: String) -> &Self::Output { - self.map.get(index.as_str()).expect("no entry found for key") +impl IndexMut for SessionOutputs<'_, '_> { + fn index_mut(&mut self, key: String) -> &mut Self::Output { + self.get_mut(&key).unwrap_or_else(|| panic!("no output named `{key}`")) } } impl Index for SessionOutputs<'_, '_> { type Output = DynValue; fn index(&self, index: usize) -> &Self::Output { - self.map.get(&self.idxs[index]).expect("no entry found for key") + if index > self.values.len() { + panic!("attempted to index output #{index} when there are only {} outputs", self.values.len()); + } + &self.values[index] + } +} + +impl IndexMut for SessionOutputs<'_, '_> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + if index > self.values.len() { + panic!("attempted to index output #{index} when there are only {} outputs", self.values.len()); + } + &mut self.values[index] + } +} + +pub struct Keys<'x, 'r> { + iter: std::slice::Iter<'x, &'r str>, + effective_len: usize +} + +impl<'r> Iterator for Keys<'_, 'r> { + type Item = &'r str; + + fn next(&mut self) -> Option { + loop { + match self.iter.next() { + None => return None, + Some(&"") => continue, + Some(x) => { + self.effective_len -= 1; + return Some(x); + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.effective_len, Some(self.effective_len)) + } +} + +impl ExactSizeIterator for Keys<'_, '_> {} +impl FusedIterator for Keys<'_, '_> {} + +pub struct Values<'x, 'k> { + value_iter: std::slice::Iter<'x, DynValue>, + key_iter: std::slice::Iter<'x, &'k str>, + effective_len: usize +} + +impl<'x> Iterator for Values<'x, '_> { + type Item = ValueRef<'x>; + + fn next(&mut self) -> Option { + loop { + match self.key_iter.next() { + None => return None, + Some(&"") => continue, + Some(_) => { + self.effective_len -= 1; + return self.value_iter.next().map(DynValue::view); + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.effective_len, Some(self.effective_len)) + } +} + +impl ExactSizeIterator for Values<'_, '_> {} +impl FusedIterator for Values<'_, '_> {} + +pub struct ValuesMut<'x, 'k> { + value_iter: std::slice::IterMut<'x, DynValue>, + key_iter: std::slice::Iter<'x, &'k str>, + effective_len: usize +} + +impl<'x> Iterator for ValuesMut<'x, '_> { + type Item = ValueRefMut<'x>; + + fn next(&mut self) -> Option { + loop { + match self.key_iter.next() { + None => return None, + Some(&"") => continue, + Some(_) => { + self.effective_len -= 1; + return self.value_iter.next().map(DynValue::view_mut); + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.effective_len, Some(self.effective_len)) + } +} + +impl ExactSizeIterator for ValuesMut<'_, '_> {} +impl FusedIterator for ValuesMut<'_, '_> {} + +pub struct Iter<'x, 'k> { + value_iter: std::slice::Iter<'x, DynValue>, + key_iter: std::slice::Iter<'x, &'k str>, + effective_len: usize +} + +impl<'x, 'k> Iterator for Iter<'x, 'k> { + type Item = (&'k str, ValueRef<'x>); + + fn next(&mut self) -> Option { + loop { + match self.key_iter.next() { + None => return None, + Some(&"") => continue, + Some(key) => { + self.effective_len -= 1; + return self.value_iter.next().map(|v| (*key, v.view())); + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.effective_len, Some(self.effective_len)) + } +} + +impl ExactSizeIterator for Iter<'_, '_> {} +impl FusedIterator for Iter<'_, '_> {} + +pub struct IterMut<'x, 'k> { + value_iter: std::slice::IterMut<'x, DynValue>, + key_iter: std::slice::Iter<'x, &'k str>, + effective_len: usize +} + +impl<'x, 'k> Iterator for IterMut<'x, 'k> { + type Item = (&'k str, ValueRefMut<'x>); + + fn next(&mut self) -> Option { + loop { + match self.key_iter.next() { + None => return None, + Some(&"") => continue, + Some(key) => { + self.effective_len -= 1; + return self.value_iter.next().map(|v| (*key, v.view_mut())); + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.effective_len, Some(self.effective_len)) + } +} + +impl ExactSizeIterator for IterMut<'_, '_> {} +impl FusedIterator for IterMut<'_, '_> {} + +pub struct IntoIter<'r, 's> { + keys: std::vec::IntoIter<&'r str>, + values: std::vec::IntoIter, + effective_len: usize, + backing_ptr: Option<(&'s Allocator, *mut c_void)> +} + +impl<'r> Iterator for IntoIter<'r, '_> { + type Item = (&'r str, DynValue); + + fn next(&mut self) -> Option { + loop { + match self.keys.next() { + None => return None, + Some("") => continue, + Some(key) => { + self.effective_len -= 1; + return self.values.next().map(|v| (key, v)); + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.effective_len, Some(self.effective_len)) + } +} + +impl Drop for IntoIter<'_, '_> { + fn drop(&mut self) { + if let Some((allocator, ptr)) = self.backing_ptr { + unsafe { allocator.free(ptr) }; + } } } diff --git a/src/training/trainer.rs b/src/training/trainer.rs index 5468247..7422451 100644 --- a/src/training/trainer.rs +++ b/src/training/trainer.rs @@ -136,7 +136,7 @@ impl Trainer { }) .collect(); - Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + Ok(SessionOutputs::new(self.train_output_names.iter().map(String::as_str).collect(), outputs)) } pub fn eval_step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( @@ -185,7 +185,7 @@ impl Trainer { }) .collect(); - Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + Ok(SessionOutputs::new(self.train_output_names.iter().map(String::as_str).collect(), outputs)) } pub fn export>(&self, out_path: impl AsRef, output_names: impl AsRef<[O]>) -> Result<()> {