Skip to content

Commit

Permalink
fix: do not require 'static for run_async inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Dec 21, 2024
1 parent c8bd1cc commit cf447cf
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
19 changes: 11 additions & 8 deletions examples/async-gpt2-api/examples/async-gpt2-api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,17 @@ struct AppState {
fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, mut tokens: Vec<i64>, gen_tokens: usize) -> impl Stream<Item = ort::Result<Event>> + Send {
async_stream_lite::try_async_stream(|yielder| async move {
for _ in 0..gen_tokens {
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let outputs = session.run_async(ort::inputs![input])?.await?;
let (dim, probabilities) = outputs["output1"].try_extract_raw_tensor()?;

// Collect and sort logits
let (seq_len, vocab_size) = (dim[2] as usize, dim[3] as usize);
let mut probabilities: Vec<(usize, f32)> = probabilities[(seq_len - 1) * vocab_size..].iter().copied().enumerate().collect();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
let probabilities = {
let input = TensorRef::from_array_view((vec![1, 1, tokens.len() as i64], tokens.as_slice()))?;
let outputs = session.run_async(ort::inputs![input])?.await?;
let (dim, probabilities) = outputs["output1"].try_extract_raw_tensor()?;

// Collect and sort logits
let (seq_len, vocab_size) = (dim[2] as usize, dim[3] as usize);
let mut probabilities: Vec<(usize, f32)> = probabilities[(seq_len - 1) * vocab_size..].iter().copied().enumerate().collect();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
probabilities
};

// Sample using top-k sampling
let token = {
Expand Down
21 changes: 12 additions & 9 deletions src/session/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{
cell::UnsafeCell,
ffi::{CString, c_char},
future::Future,
marker::PhantomData,
ops::Deref,
pin::Pin,
ptr::NonNull,
Expand Down Expand Up @@ -81,23 +82,25 @@ impl<O: SelectedOutputMarker> Deref for RunOptionsRef<'_, O> {
}
}

pub struct InferenceFut<'s, 'r, O: SelectedOutputMarker> {
pub struct InferenceFut<'s, 'r, 'v, O: SelectedOutputMarker> {
inner: Arc<InferenceFutInner<'r, 's>>,
run_options: RunOptionsRef<'r, O>,
did_receive: bool
did_receive: bool,
_inputs: PhantomData<&'v ()>
}

impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, O> {
impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, '_, O> {
pub(crate) fn new(inner: Arc<InferenceFutInner<'r, 's>>, run_options: RunOptionsRef<'r, O>) -> Self {
Self {
inner,
run_options,
did_receive: false
did_receive: false,
_inputs: PhantomData
}
}
}

impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, O> {
impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, '_, O> {
type Output = Result<SessionOutputs<'r, 's>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Expand All @@ -113,7 +116,7 @@ impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, O> {
}
}

impl<O: SelectedOutputMarker> Drop for InferenceFut<'_, '_, O> {
impl<O: SelectedOutputMarker> Drop for InferenceFut<'_, '_, '_, O> {
fn drop(&mut self) {
if !self.did_receive {
let _ = self.run_options.terminate();
Expand All @@ -122,9 +125,9 @@ impl<O: SelectedOutputMarker> Drop for InferenceFut<'_, '_, O> {
}
}

pub(crate) struct AsyncInferenceContext<'r, 's> {
pub(crate) struct AsyncInferenceContext<'r, 's, 'v> {
pub(crate) inner: Arc<InferenceFutInner<'r, 's>>,
pub(crate) _input_values: Vec<SessionInputValue<'s>>,
pub(crate) _input_values: Vec<SessionInputValue<'v>>,
pub(crate) input_ort_values: Vec<*const ort_sys::OrtValue>,
pub(crate) input_name_ptrs: Vec<*const c_char>,
pub(crate) output_name_ptrs: Vec<*const c_char>,
Expand All @@ -134,7 +137,7 @@ pub(crate) struct AsyncInferenceContext<'r, 's> {
}

pub(crate) extern "system" fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: *mut OrtStatus) {
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_, '_>>()) };

// Reconvert name ptrs to CString so drop impl is called and memory is freed
for p in ctx.input_name_ptrs {
Expand Down
10 changes: 5 additions & 5 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ impl Session {
/// ```
pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>(
&'s self,
input_values: impl Into<SessionInputs<'i, 'v, N>> + 'static
) -> Result<InferenceFut<'s, 's, NoSelectedOutputs>> {
input_values: impl Into<SessionInputs<'i, 'v, N>>
) -> Result<InferenceFut<'s, 's, 'v, NoSelectedOutputs>> {
match input_values.into() {
SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"),
SessionInputs::ValueArray(input_values) => {
Expand All @@ -372,9 +372,9 @@ impl Session {
/// See [`Session::run_with_options`] and [`Session::run_async`] for more details.
pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>(
&'s self,
input_values: impl Into<SessionInputs<'i, 'v, N>> + 'static,
input_values: impl Into<SessionInputs<'i, 'v, N>>,
run_options: &'r RunOptions<O>
) -> Result<InferenceFut<'s, 'r, O>> {
) -> Result<InferenceFut<'s, 'r, 'v, O>> {
match input_values.into() {
SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"),
SessionInputs::ValueArray(input_values) => {
Expand All @@ -393,7 +393,7 @@ impl Session {
input_names: &[String],
input_values: impl Iterator<Item = SessionInputValue<'v>>,
run_options: Option<&'r RunOptions<O>>
) -> Result<InferenceFut<'s, 'r, O>> {
) -> Result<InferenceFut<'s, 'r, 'v, O>> {
let run_options = match run_options {
Some(r) => RunOptionsRef::Ref(r),
// create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial
Expand Down

0 comments on commit cf447cf

Please sign in to comment.