Skip to content

Commit

Permalink
refactor!: reduce allocations in run; make SessionOutputs not a map
Browse files (browse the repository at this point in the history)
  • Loading branch information
decahedron1 committed Nov 5, 2024
1 parent 552727e commit 8a16adb
Show file tree
Hide file tree
Showing 5 changed files with 383 additions and 65 deletions.
10 changes: 8 additions & 2 deletions src/io_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,17 @@ impl IoBinding {
Some(Arc::clone(&self.session))
)
}
});
})
.collect::<Vec<_>>();

// output values will be freed when the `Value`s in `SessionOutputs` drop

Ok(SessionOutputs::new_backed(self.output_names.iter().map(String::as_str), output_values, &self.session.allocator, output_values_ptr.cast()))
Ok(SessionOutputs::new_backed(
self.output_names.iter().map(String::as_str).collect(),
output_values,
&self.session.allocator,
output_values_ptr.cast()
))
} else {
Ok(SessionOutputs::new_empty())
}
Expand Down
18 changes: 5 additions & 13 deletions src/session/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::{
use ort_sys::{OrtStatus, c_void};

use crate::{
error::{Result, assert_non_null_pointer},
error::Result,
session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner},
value::Value
};
Expand Down Expand Up @@ -138,17 +138,9 @@ crate::extern_system_fn! {
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };

// Reconvert name ptrs to CString so drop impl is called and memory is freed
drop(
ctx.input_name_ptrs
.into_iter()
.chain(ctx.output_name_ptrs)
.map(|p| {
assert_non_null_pointer(p, "c_char for CString")?;
unsafe { Ok(CString::from_raw(p.cast_mut().cast())) }
})
.collect::<Result<Vec<_>>>()
.expect("Input name should not be null")
);
for p in ctx.input_name_ptrs {
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
}

if let Err(e) = crate::error::status_to_result(status) {
ctx.inner.emplace_value(Err(e));
Expand All @@ -164,7 +156,7 @@ crate::extern_system_fn! {
})
.collect();

ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names.into_iter(), outputs)));
ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names, outputs)));
ctx.inner.wake();
}
}
17 changes: 5 additions & 12 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,11 @@ impl Session {
.collect();

// Reconvert name ptrs to CString so drop impl is called and memory is freed
drop(
input_names_ptr
.into_iter()
.chain(output_names_ptr.into_iter())
.map(|p| {
assert_non_null_pointer(p, "c_char for CString")?;
unsafe { Ok(CString::from_raw(p.cast_mut().cast())) }
})
.collect::<Result<Vec<_>>>()?
);

Ok(SessionOutputs::new(output_names.into_iter(), outputs))
for p in input_names_ptr.into_iter().chain(output_names_ptr.into_iter()) {
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
}

Ok(SessionOutputs::new(output_names, outputs))
}

/// Asynchronously run input data through the ONNX graph, performing inference.
Expand Down
Loading

0 comments on commit 8a16adb

Please sign in to comment.