Skip to content

Commit

Permalink
chore: update doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Dec 21, 2024
1 parent c5a06a7 commit ec24307
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 83 deletions.
6 changes: 3 additions & 3 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl Drop for AdapterInner {
/// run_options.add_adapter(&lora)?;
///
/// let outputs =
/// model.run_with_options(ort::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?;
/// model.run_with_options(ort::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?], &run_options)?;
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -84,7 +84,7 @@ impl Adapter {
/// run_options.add_adapter(&lora)?;
///
/// let outputs =
/// model.run_with_options(ort::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?;
/// model.run_with_options(ort::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?], &run_options)?;
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -130,7 +130,7 @@ impl Adapter {
/// run_options.add_adapter(&lora)?;
///
/// let outputs =
/// model.run_with_options(ort::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?]?, &run_options)?;
/// model.run_with_options(ort::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?], &run_options)?;
/// # Ok(())
/// # }
/// ```
Expand Down
2 changes: 1 addition & 1 deletion src/io_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ use crate::{
/// .run(ort::inputs![Tensor::<i64>::from_array((vec![27], vec![
/// 23763, 15460, 473, 68, 312, 265, 17463, 4098, 304, 1077, 283, 198, 7676, 5976, 272, 285, 3609, 435, 21680,
/// 321, 265, 300, 1689, 64, 285, 4763, 64
/// ]))?]?)?
/// ]))?])?
/// .remove("output0")
/// .unwrap();
///
Expand Down
25 changes: 6 additions & 19 deletions src/session/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,15 @@ impl<'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<'_,
///
/// # Example
///
/// ## Array of tensors
/// ## Array of values
///
/// ```no_run
/// # use std::{error::Error, sync::Arc};
/// # use ndarray::Array1;
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # use ort::{value::Tensor, session::{builder::GraphOptimizationLevel, Session}};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// # let mut session = Session::builder()?.commit_from_file("model.onnx")?;
/// let _ = session.run(ort::inputs![Array1::from_vec(vec![1, 2, 3, 4, 5])]?);
/// # Ok(())
/// # }
/// ```
///
/// Note that string tensors must be created manually with [`Tensor::from_string_array`].
///
/// ```no_run
/// # use std::{error::Error, sync::Arc};
/// # use ndarray::Array1;
/// # use ort::{session::{builder::GraphOptimizationLevel, Session}, value::Tensor};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// # let mut session = Session::builder()?.commit_from_file("model.onnx")?;
/// let _ = session.run(ort::inputs![Tensor::from_string_array(Array1::from_vec(vec!["hello", "world"]))?]?);
/// let _ = session.run(ort::inputs![Tensor::from_array(([5], vec![1, 2, 3, 4, 5]))?])?;
/// # Ok(())
/// # }
/// ```
Expand All @@ -114,12 +101,12 @@ impl<'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<'_,
/// ```no_run
/// # use std::{error::Error, sync::Arc};
/// # use ndarray::Array1;
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # use ort::{value::Tensor, session::{builder::GraphOptimizationLevel, Session}};
/// # fn main() -> Result<(), Box<dyn Error>> {
/// # let mut session = Session::builder()?.commit_from_file("model.onnx")?;
/// let _ = session.run(ort::inputs! {
/// "tokens" => Array1::from_vec(vec![1, 2, 3, 4, 5])
/// }?);
/// "tokens" => Tensor::from_array(([5], vec![1, 2, 3, 4, 5]))?
/// })?;
/// # Ok(())
/// # }
/// ```
Expand Down
24 changes: 12 additions & 12 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! Contains [`Session`], the main interface used to inference ONNX models.
//!
//! ```
//! # use ort::session::Session;
//! # use ort::{session::Session, value::TensorRef};
//! # fn main() -> ort::Result<()> {
//! let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
//! let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
//! let outputs = session.run(ort::inputs![input]?)?;
//! let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?;
//! # Ok(())
//! # }
//! ```
Expand Down Expand Up @@ -82,11 +82,11 @@ impl Drop for SharedSessionInner {
/// An ONNX Runtime graph to be used for inference.
///
/// ```
/// # use ort::session::Session;
/// # use ort::{session::Session, value::TensorRef};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
/// let outputs = session.run(ort::inputs![input]?)?;
/// let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?;
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -187,11 +187,11 @@ impl Session {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::{run_options::RunOptions, Session}, tensor::TensorElementType, value::{Value, ValueType}};
/// # use ort::{session::{run_options::RunOptions, Session}, tensor::TensorElementType, value::{Value, ValueType, TensorRef}};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
/// let outputs = session.run(ort::inputs![input]?)?;
/// let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?;
/// # Ok(())
/// # }
/// ```
Expand All @@ -217,7 +217,7 @@ impl Session {
/// ```no_run
/// # // no_run because upsample.onnx is too simple of a model for the termination signal to be reliable enough
/// # use std::sync::Arc;
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType}, tensor::TensorElementType};
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// # let input = Value::from_array(ndarray::Array4::<f32>::zeros((1, 64, 64, 3)))?;
Expand All @@ -228,7 +228,7 @@ impl Session {
/// let _ = run_options_.terminate();
/// });
///
/// let res = session.run_with_options(ort::inputs![input]?, &*run_options);
/// let res = session.run_with_options(ort::inputs![&input], &*run_options);
/// // upon termination, the session will return an `Error::SessionRun` error.`
/// assert_eq!(
/// &res.unwrap_err().to_string(),
Expand Down Expand Up @@ -345,11 +345,11 @@ impl Session {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType}, tensor::TensorElementType};
/// # use ort::{session::{Session, run_options::RunOptions}, value::{Value, ValueType, TensorRef}, tensor::TensorElementType};
/// # fn main() -> ort::Result<()> { tokio_test::block_on(async {
/// let session = Session::builder()?.with_intra_threads(2)?.commit_from_file("tests/data/upsample.onnx")?;
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
/// let outputs = session.run_async(ort::inputs![input]?)?.await?;
/// let outputs = session.run_async(ort::inputs![TensorRef::from_array_view(&input)?])?.await?;
/// # Ok(())
/// # }) }
/// ```
Expand Down Expand Up @@ -477,13 +477,13 @@ impl Session {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::{session::{run_options::RunOptions, Session, WorkloadType}, tensor::TensorElementType, value::{Value, ValueType}};
/// # use ort::{session::{run_options::RunOptions, Session, WorkloadType}, tensor::TensorElementType, value::{Value, ValueType, TensorRef}};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// session.set_workload_type(WorkloadType::Efficient)?;
///
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
/// let outputs = session.run(ort::inputs![input]?)?;
/// let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?;
/// # Ok(())
/// # }
/// ```
Expand Down
4 changes: 2 additions & 2 deletions src/session/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ use crate::{
/// This type allows session outputs to be retrieved by index or by name.
///
/// ```
/// # use ort::session::{builder::GraphOptimizationLevel, Session};
/// # use ort::{value::TensorRef, session::{builder::GraphOptimizationLevel, Session}};
/// # fn main() -> ort::Result<()> {
/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let input = ndarray::Array4::<f32>::zeros((1, 64, 64, 3));
/// let outputs = session.run(ort::inputs![input]?)?;
/// let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?;
///
/// // get the first output
/// let output = &outputs[0];
Expand Down
10 changes: 5 additions & 5 deletions src/session/run_options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::{
/// );
///
/// // `outputs[0]` will be the tensor we just pre-allocated.
/// let outputs = session.run_with_options(ort::inputs![input]?, &options)?;
/// let outputs = session.run_with_options(ort::inputs![input], &options)?;
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -108,7 +108,7 @@ impl OutputSelector {
/// OutputSelector::default().preallocate(output0, Tensor::<f32>::new(&Allocator::default(), [1, 128, 128, 3])?)
/// );
///
/// let outputs = session.run_with_options(ort::inputs![input]?, &options)?;
/// let outputs = session.run_with_options(ort::inputs![input], &options)?;
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -212,7 +212,7 @@ impl<O: SelectedOutputMarker> RunOptions<O> {
/// );
///
/// // `outputs[0]` will be the tensor we just pre-allocated.
/// let outputs = session.run_with_options(ort::inputs![input]?, &options)?;
/// let outputs = session.run_with_options(ort::inputs![input], &options)?;
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -262,7 +262,7 @@ impl<O: SelectedOutputMarker> RunOptions<O> {
/// let _ = run_options_.terminate();
/// });
///
/// let res = session.run_with_options(ort::inputs![input]?, &*run_options);
/// let res = session.run_with_options(ort::inputs![input], &*run_options);
/// // upon termination, the session will return an `Error::SessionRun` error.`
/// assert_eq!(
/// &res.unwrap_err().to_string(),
Expand Down Expand Up @@ -293,7 +293,7 @@ impl<O: SelectedOutputMarker> RunOptions<O> {
/// let _ = run_options_.unterminate();
/// });
///
/// let res = session.run_with_options(ort::inputs![input]?, &*run_options);
/// let res = session.run_with_options(ort::inputs![input], &*run_options);
/// assert!(res.is_ok());
/// # Ok(())
/// # }
Expand Down
4 changes: 2 additions & 2 deletions src/value/impl_tensor/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ impl Tensor<String> {
/// # fn main() -> ort::Result<()> {
/// // Create a string tensor from a raw data vector
/// let data = vec!["hello", "world"];
/// let value = Tensor::from_string_array(([data.len()], data.into_boxed_slice()))?;
/// let value = Tensor::from_string_array(([data.len()], &*data))?;
///
/// // Create a string tensor from an `ndarray::Array`
/// #[cfg(feature = "ndarray")]
/// let value = Tensor::from_string_array(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap())?;
/// let value = Tensor::from_string_array(&ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap())?;
/// # Ok(())
/// # }
/// ```
Expand Down
73 changes: 35 additions & 38 deletions src/value/impl_tensor/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::value::Tensor;
/// # use ort::value::TensorRef;
/// # fn main() -> ort::Result<()> {
/// let array = ndarray::Array4::<f32>::ones((1, 16, 16, 3));
/// let value = Tensor::from_array(array.view())?.into_dyn();
/// let value = TensorRef::from_array_view(array.view())?.into_dyn();
///
/// let extracted = value.try_extract_tensor::<f32>()?;
/// assert_eq!(array.into_dyn(), extracted);
/// assert_eq!(array.view().into_dyn(), extracted);
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -130,18 +130,16 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::value::Tensor;
/// # use ort::value::TensorRefMut;
/// # fn main() -> ort::Result<()> {
/// let array = ndarray::Array4::<f32>::ones((1, 16, 16, 3));
/// let mut value = Tensor::from_array(array.view())?.into_dyn();
///
/// let mut extracted = value.try_extract_tensor_mut::<f32>()?;
/// extracted[[0, 0, 0, 1]] = 0.0;
///
/// let mut array = array.into_dyn();
/// assert_ne!(array, extracted);
/// array[[0, 0, 0, 1]] = 0.0;
/// assert_eq!(array, extracted);
/// let mut array = ndarray::Array4::<f32>::ones((1, 16, 16, 3));
/// {
/// let mut value = TensorRefMut::from_array_view_mut(array.view_mut())?.into_dyn();
/// let mut extracted = value.try_extract_tensor_mut::<f32>()?;
/// extracted[[0, 0, 0, 1]] = 0.0;
/// }
///
/// assert_eq!(array[[0, 0, 0, 1]], 0.0);
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -290,7 +288,7 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
/// # use ort::value::Tensor;
/// # fn main() -> ort::Result<()> {
/// let array = ndarray::Array1::from_vec(vec!["hello", "world"]);
/// let tensor = Tensor::from_string_array(array.clone())?.into_dyn();
/// let tensor = Tensor::from_string_array(&array)?.into_dyn();
///
/// let extracted = tensor.try_extract_string_tensor()?;
/// assert_eq!(array.into_dyn(), extracted);
Expand Down Expand Up @@ -450,13 +448,13 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::value::Tensor;
/// # use ort::value::TensorRef;
/// # fn main() -> ort::Result<()> {
/// let array = ndarray::Array4::<f32>::ones((1, 16, 16, 3));
/// let tensor = Tensor::from_array(array.view())?;
/// let tensor = TensorRef::from_array_view(&array)?;
///
/// let extracted = tensor.extract_tensor();
/// assert_eq!(array.into_dyn(), extracted);
/// assert_eq!(array.view().into_dyn(), extracted);
/// # Ok(())
/// # }
/// ```
Expand All @@ -470,18 +468,16 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
///
/// ```
/// # use std::sync::Arc;
/// # use ort::value::Tensor;
/// # use ort::value::TensorRefMut;
/// # fn main() -> ort::Result<()> {
/// let array = ndarray::Array4::<f32>::ones((1, 16, 16, 3));
/// let mut tensor = Tensor::from_array(array.view())?;
///
/// let mut extracted = tensor.extract_tensor_mut();
/// extracted[[0, 0, 0, 1]] = 0.0;
///
/// let mut array = array.into_dyn();
/// assert_ne!(array, extracted);
/// array[[0, 0, 0, 1]] = 0.0;
/// assert_eq!(array, extracted);
/// let mut array = ndarray::Array4::<f32>::ones((1, 16, 16, 3));
/// {
/// let mut tensor = TensorRefMut::from_array_view_mut(array.view_mut())?;
/// let mut extracted = tensor.extract_tensor_mut();
/// extracted[[0, 0, 0, 1]] = 0.0;
/// }
///
/// assert_eq!(array[[0, 0, 0, 1]], 0.0);
/// # Ok(())
/// # }
/// ```
Expand All @@ -495,10 +491,10 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// view into its data.
///
/// ```
/// # use ort::value::Tensor;
/// # use ort::value::TensorRef;
/// # fn main() -> ort::Result<()> {
/// let array = vec![1_i64, 2, 3, 4, 5];
/// let tensor = Tensor::from_array(([array.len()], array.clone().into_boxed_slice()))?;
/// let tensor = TensorRef::from_array_view(([array.len()], &*array))?;
///
/// let (extracted_shape, extracted_data) = tensor.extract_raw_tensor();
/// assert_eq!(extracted_data, &array);
Expand All @@ -514,14 +510,15 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// into its data.
///
/// ```
/// # use ort::value::Tensor;
/// # use ort::value::TensorRefMut;
/// # fn main() -> ort::Result<()> {
/// let array = vec![1_i64, 2, 3, 4, 5];
/// let tensor = Tensor::from_array(([array.len()], array.clone().into_boxed_slice()))?;
///
/// let (extracted_shape, extracted_data) = tensor.extract_raw_tensor();
/// assert_eq!(extracted_data, &array);
/// assert_eq!(extracted_shape, [5]);
/// let mut original_array = vec![1_i64, 2, 3, 4, 5];
/// {
/// let mut tensor = TensorRefMut::from_array_view_mut(([original_array.len()], &mut *original_array))?;
/// let (extracted_shape, extracted_data) = tensor.extract_raw_tensor_mut();
/// extracted_data[2] = 42;
/// }
/// assert_eq!(original_array, [1, 2, 42, 4, 5]);
/// # Ok(())
/// # }
/// ```
Expand Down
2 changes: 1 addition & 1 deletion src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl<Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'_, Type> {
/// let value = Tensor::from_array(ndarray::Array4::<f32>::zeros((1, 16, 16, 3)))?;
///
/// // Get a DynValue from a session's output
/// let value = &upsample.run(ort::inputs![value]?)?[0];
/// let value = &upsample.run(ort::inputs![value])?[0];
/// # Ok(())
/// # }
/// ```
Expand Down

0 comments on commit ec24307

Please sign in to comment.