From 1a505493fe50a7a30df2d0eee4f533cc5414a1f1 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Thu, 26 Dec 2024 19:05:39 -0600 Subject: [PATCH] docs: update docs to reflect new `Tensor` changes --- docs/pages/fundamentals/value.mdx | 174 +++++++++++++++++++----- docs/pages/migrating/v2.mdx | 46 ++----- docs/pages/perf/execution-providers.mdx | 2 +- docs/pages/perf/io-binding.mdx | 4 +- docs/pages/setup/cargo-features.mdx | 2 +- src/session/input.rs | 7 - src/value/impl_tensor/create.rs | 104 ++++++++++---- src/value/impl_tensor/extract.rs | 2 +- src/value/type.rs | 4 +- 9 files changed, 236 insertions(+), 109 deletions(-) diff --git a/docs/pages/fundamentals/value.mdx b/docs/pages/fundamentals/value.mdx index bbd78e6..7eff1ab 100644 --- a/docs/pages/fundamentals/value.mdx +++ b/docs/pages/fundamentals/value.mdx @@ -9,12 +9,134 @@ For ONNX Runtime, a **value** represents any type that can be given to/returned - **Maps** map a key type to a value type, similar to Rust's `HashMap`. - **Sequences** are homogenously-typed dynamically-sized lists, similar to Rust's `Vec`. The only values allowed in sequences are tensors, or maps of tensors. -In order to actually use the data in these containers, you can use the `.try_extract_*` methods. `try_extract_tensor(_mut)` extracts an `ndarray::ArrayView(Mut)` from the value if it is a tensor. `try_extract_sequence` returns a `Vec` of values, and `try_extract_map` returns a `HashMap`. +## Creating values -Sessions in `ort` return a map of `DynValue`s. You can determine a value's type via its `.dtype()` method. You can also use fallible methods to extract data from this value - for example, [`DynValue::try_extract_tensor`](https://docs.rs/ort/2.0.0-rc.8/ort/type.DynValue.html#method.try_extract_tensor), which fails if the value is not a tensor. Often times though, you'll want to reuse the same value which you are certain is a tensor - in which case, you can **downcast** the value. +### Creating tensors +Tensors can be created with [`Tensor::from_array`](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.Tensor.html#method.from_array) from either: +- an [`ndarray::Array`](https://docs.rs/ndarray/0.16.1/ndarray/type.Array.html), or +- a tuple of `(shape, data)`, where: + - `shape` is one of `Vec`, `[I; N]` or `&[I]`, where `I` is `i64` or `usize`, and + - `data` is one of `Vec` or `Box<[T]>`. -## Downcasting -**Downcasting** means to convert a `Dyn` type like `DynValue` to stronger type like `DynTensor`. Downcasting can be performed using the `.downcast()` function on `DynValue`: +```rs +let tensor = Tensor::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; + +let tensor = Tensor::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]))?; +``` + +The created tensor will take ownership of the passed data. See [Creating views of external data](#creating-views-of-external-data) to create temporary tensors referencing borrowed data. + +### Creating maps & sequences +`Map`s can be [created](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.Map.html#method.new) from any iterator yielding tuples of `(K, V)`, where `K` and `V` are tensor element types. + +```rs +let mut map = HashMap::::new(); +map.insert("one".to_string(), 1.0); +map.insert("two".to_string(), 2.0); +map.insert("three".to_string(), 3.0); + +let map = Map::::new(map)?; +``` + +`Map`s can also be [created from 2 tensors](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.Map.html#method.new_kv), one containing keys and the other containing values: +```rs +let keys = Tensor::::from_array(([4], vec![0, 1, 2, 3]))?; +let values = Tensor::::from_array(([4], vec![1., 2., 3., 4.]))?; + +let map = Map::new_kv(keys, values)?; +``` + +`Sequence`s can be [created] from any iterator yielding a `Value` subtype: +```rs +let tensor1 = Tensor::::new(&allocator, [1, 128, 128, 3])?; +let tensor2 = Tensor::::new(&allocator, [1, 224, 224, 3])?; + +let sequence: Sequence> = Sequence::new(vec![tensor1, tensor2])?; +``` + +## Using values +Values can be used as an input in a session's [`run`](https://docs.rs/ort/2.0.0-rc.9/ort/session/struct.Session.html#method.run) function - either by value, by reference, or [by view](#views). +```rs +let latents = Tensor::::new(&allocator, [1, 128, 128, 3])?; +let text_embedding = Tensor::::new(&allocator, [1, 48, 256])?; +let timestep = Tensor::::new(&allocator, [1])?; + +let outputs = session.run(ort::inputs![ + "timestep" => timestep, + "latents" => &latents, + "text_embedding" => text_embedding.view() +])?; +``` + +### Extracting data +To access the underlying data of a value directly, the data must first be **extracted**. + +`Tensor`s can either [extract to an `ArrayView`](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.Tensor.html#method.extract_tensor), or [extract to a tuple](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.Tensor.html#method.extract_raw_tensor) of `(&[i64], &[T])`, where the first element is the shape of the tensor, and the second is the slice of data contained within the tensor. +```rs +let array = ndarray::Array4::::ones((1, 16, 16, 3)); +let tensor = TensorRef::from_array_view(&array)?; + +let extracted: ArrayViewD<'_, f32> = tensor.extract_tensor(); +let (tensor_shape, extracted_data): (&[i64], &[f32]) = tensor.extract_raw_tensor(); +``` + +`Tensor`s and `TensorRefMut`s with non-string elements can also be mutably extracted with `extract_tensor_mut` and `extract_raw_tensor_mut`. Mutating the returned types will directly update the data contained within the tensor. +```rs +let mut original_array = vec![1_i64, 2, 3, 4, 5]; +{ + let mut tensor = TensorRefMut::from_array_view_mut(([original_array.len()], &mut *original_array))?; + let (extracted_shape, extracted_data) = tensor.extract_raw_tensor_mut(); + extracted_data[2] = 42; +} +assert_eq!(original_array, [1, 2, 42, 4, 5]); +``` + +`Map` and `Sequence` have [`Map::extract_map`](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.Map.html#method.extract_map) and [`Sequence::extract_sequence`](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.Sequence.html#method.extract_sequence), which emit a `HashMap` and a `Vec` of value [views](#views) respectively. Unlike `extract_tensor`, these types cannot mutably extract their data, and always allocate on each `extract` call, making them more computationally expensive. + +Session outputs return `DynValue`s, which are values whose [type is not known at compile time](#dynamic-values). In order to extract data from a `DynValue`, you must either [downcast it to a strong type](#downcasting) or use a corresponding `try_extract_*` method, which fails if the value's type is not compatible: +```rs +let outputs = session.run(ort::inputs![TensorRef::from_array_view(&input)?])?; + +let Ok(tensor_output): ort::Result> = outputs[0].try_extract_tensor() else { + panic!("First output was not a Tensor!"); +} +``` + +## Views +A view (also called a ref) is functionally a borrowed variant of a value. There are also mutable views, which are equivalent to mutably borrowed values. Views are represented as separate structs so that they can be down/upcasted. + +View types are suffixed with `Ref` or `RefMut` for shared/mutable variants respectively: +- Tensors have `DynTensorRef(Mut)` and `TensorRef(Mut)`. +- Maps have `DynMapRef(Mut)` and `MapRef(Mut)`. +- Sequences have `DynSequenceRef(Mut)` and `SequenceRef(Mut)`. + +These views can be acquired with `.view()` or `.view_mut()` on a value type: +```rs +let my_tensor: ort::value::Tensor = Tensor::new(...)?; + +let tensor_view: ort::value::TensorRef<'_, f32> = my_tensor.view(); +``` + +Views act identically to a borrow of their type - `TensorRef` supports `extract_tensor`, `TensorRefMut` supports `extract_tensor_mut`. The same is true for sequences & maps. + +### Creating views of external data +You can create `TensorRef`s and `TensorRefMut`s from views of external data, like an `ndarray` array, or a raw slice of data. These types act almost identically to a `Tensor` - you can extract them and pass them as session inputs - but as they do not take ownership of the data, they are bound to the input's lifetime. + +```rs +let original_data = Array4::::from_shape_vec(...); +let tensor_view = TensorRef::from_array_view(original_data.view())?; + +let mut original_data = vec![...]; +let tensor_view_mut = TensorRefMut::from_array_view_mut(([1, 3, 64, 64], &mut *original_data))?; +``` + +## Dynamic values +Sessions in `ort` return a map of `DynValue`s. These are values whose exact type is not known at compile time. You can determine a value's [type](https://docs.rs/ort/2.0.0-rc.9/ort/value/enum.ValueType.html) via its `.dtype()` method. + +You can also use fallible methods to extract data from this value - for example, [`DynValue::try_extract_tensor`](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.DynValue.html#method.try_extract_tensor), which fails if the value is not a tensor. Often times though, you'll want to reuse the same value which you are certain is a tensor - in which case, you can **downcast** the value. + +### Downcasting +**Downcasting** means to convert a dyn type like `DynValue` to stronger type like `DynTensor`. Downcasting can be performed using the `.downcast()` function on `DynValue`: ```rs let value: ort::value::DynValue = outputs.remove("output0").unwrap(); @@ -23,10 +145,8 @@ let dyn_tensor: ort::value::DynTensor = value.downcast()?; If `value` is not actually a tensor, the `downcast()` call will fail. -`DynTensor` allows you to use - -### Stronger types -`DynTensor` means that the type **is** a tensor, but the *element type is unknown*. There are also `DynSequence`s and `DynMap`s, which have the same meaning - the element/key/value types are unknown. +#### Stronger types +`DynTensor` means that the type **is** a tensor, but the *element type is unknown*. There are also `DynSequence`s and `DynMap`s, which have the same meaning - the *kind* of value is known, but the element/key/value types are not. The strongly typed variants of these types - `Tensor`, `Sequence`, and `Map`, can be directly downcasted to, too: ```rs @@ -47,7 +167,7 @@ let tensor: ort::value::Tensor = dyn_value.downcast()?; let f32_array = tensor.extract_tensor(); // no `?` required, this will never fail! ``` -## Upcasting +### Upcasting **Upcasting** means to convert a strongly-typed value type like `Tensor` to a weaker type like `DynTensor` or `DynValue`. This can be useful if you have code that stores values of different types, e.g. in a `HashMap`. Strongly-typed value types like `Tensor` can be converted into a `DynTensor` using `.upcast()`: @@ -64,7 +184,17 @@ let dyn_value = f32_tensor.into_dyn(); Upcasting a value doesn't change its underlying type; it just removes the specialization. You cannot, for example, upcast a `Tensor` to a `DynValue` and then downcast it to a `Sequence`; it's still a `Tensor`, just contained in a different type. -## Conversion recap +### Dyn views +Views also support down/upcasting via `.downcast()` & `.into_dyn()` (but not `.upcast()` at the moment). + +You can also directly downcast a value to a stronger-typed view using `.downcast_ref()` and `.downcast_mut()`: +```rs +let tensor_view: ort::value::TensorRef<'_, f32> = dyn_value.downcast_ref()?; +// is equivalent to +let tensor_view: ort::value::TensorRef<'_, f32> = dyn_value.view().downcast()?; +``` + +### Conversion recap - `DynValue` represents a value that can be any type - tensor, sequence, or map. The type can be retrieved with `.dtype()`. - `DynTensor`, `DynMap`, and `DynSequence` are values with known container types, but unknown element types. - `Tensor`, `Map`, and `Sequence` are values with known container and element types. @@ -78,27 +208,3 @@ Upcasting a value doesn't change its underlying type; it just removes the specia Downcasts are cheap, as they only check the value's type. Upcasts compile to a no-op. - -## Views -A view (also called a ref) is functionally a borrowed variant of a value. There are also mutable views, which are equivalent to mutably borrowed values. Views are represented as separate structs so that they can be down/upcasted. - -View types are suffixed with `Ref` or `RefMut` for shared/mutable variants respectively: -- Tensors have `DynTensorRef(Mut)` and `TensorRef(Mut)`. -- Maps have `DynMapRef(Mut)` and `MapRef(Mut)`. -- Sequences have `DynSequenceRef(Mut)` and `SequenceRef(Mut)`. - -These views can be acquired with `.view()` or `.view_mut()` on a value type: -```rs -let my_tensor: ort::value::Tensor = Tensor::new(...)?; - -let tensor_view: ort::value::TensorRef<'_, f32> = my_tensor.view(); -``` - -Views act identically to a borrow of their type - `TensorRef` supports `extract_tensor`, `TensorRefMut` supports `extract_tensor_mut`. The same is true for sequences & maps. Views also support down/upcasting via `.downcast()` & `.into_dyn()` (but not `.upcast()` at the moment). - -You can also directly downcast a value to a stronger-typed view using `.downcast_ref()` and `.downcast_mut()`: -```rs -let tensor_view: ort::value::TensorRef<'_, f32> = dyn_value.downcast_ref()?; -// is equivalent to -let tensor_view: ort::value::TensorRef<'_, f32> = dyn_value.view().downcast()?; -``` diff --git a/docs/pages/migrating/v2.mdx b/docs/pages/migrating/v2.mdx index 6cad73d..ff3fefe 100644 --- a/docs/pages/migrating/v2.mdx +++ b/docs/pages/migrating/v2.mdx @@ -97,25 +97,8 @@ The final `SessionBuilder` methods have been renamed for clarity. ## Session inputs -### `CowArray`/`IxDyn`/`ndarray` no longer required -One of the biggest usability changes is that the usual pattern of `CowArray::from(array.into_dyn())` is no longer required to create tensors. Now, tensors can be created from: -- Owned `Array`s of any dimensionality -- `ArrayView`s of any dimensionality -- Shared references to `CowArray`s of any dimensionality (i.e. `&CowArray<'_, f32, Ix3>`) -- Mutable references to `ArcArray`s of any dimensionality (i.e. `&mut ArcArray`) -- A raw shape definition & data array, of type `(Vec, Arc>)` - -```diff --// v1.x --let mut tokens = CowArray::from(Array1::from_iter(tokens.iter().cloned()).into_dyn()); -+// v2 -+let mut tokens = Array1::from_iter(tokens.iter().cloned()); -``` - -It should be noted that there are some cases in which an array is cloned when converting into a tensor which may lead to a surprising performance hit. ONNX Runtime does not expose an API to specify the strides of a tensor, so if an array is reshaped before being converted into a tensor, it must be cloned in order to make the data contiguous. Specifically: -- `&CowArray`, `ArrayView` will **always be cloned** (due to the fact that we cannot guarantee the lifetime of the array). -- `Array`, `&mut ArcArray` will only be cloned **if the memory layout is not contiguous**, i.e. if it has been reshaped. -- Raw data will never be cloned as it is assumed to already have a contiguous memory layout. +### Tensor creation +You can now create input tensors from `Array`s and `ArrayView`s. See the [tensor value documentation](/fundamentals/value#creating-values) for more information. ### `ort::inputs!` macro @@ -127,18 +110,16 @@ The `ort::inputs!` macro will painlessly convert compatible data types (see abov -// v1.x -let chunk_embeddings = text_encoder.run(&[CowArray::from(text_input_chunk.into_dyn())])?; +// v2 -+let chunk_embeddings = text_encoder.run(ort::inputs![text_input_chunk]?)?; ++let chunk_embeddings = text_encoder.run(ort::inputs![text_input_chunk])?; ``` -Note the `?` after the macro call - `ort::inputs!` returns an `ort::Result`, so you'll need to handle any errors accordingly. - As mentioned, you can now also specify inputs by name using a map-like syntax. This is especially useful for graphs with optional inputs. ```rust let noise_pred = unet.run(ort::inputs![ - "latents" => latents, - "timestep" => Array1::from_iter([t]), + "latents" => &latents, + "timestep" => Tensor::from_array(([1], vec![t]))?, "encoder_hidden_states" => text_embeddings.view() -]?)?; +])?; ``` ### Tensor creation no longer requires the session's allocator @@ -148,19 +129,19 @@ In previous versions, `Value::from_array` took an allocator parameter. The alloc -// v1.x -let val = Value::from_array(session.allocator(), &array)?; +// v2 -+let val = Tensor::from_array(&array)?; ++let val = Tensor::from_array(array)?; ``` ### Separate string tensor creation As previously mentioned, the logic for creating string tensors has been moved from `Value::from_array` to `DynTensor::from_string_array`. -To use string tensors with `ort::inputs!`, you must create a `DynTensor` using `DynTensor::from_string_array`. +To use string tensors with `ort::inputs!`, you must create a `Tensor` using `Tensor::from_string_array`. ```rust let array = ndarray::Array::from_shape_vec((1,), vec![document]).unwrap(); let outputs = session.run(ort::inputs![ - "input" => DynTensor::from_string_array(session.allocator(), array)? -]?)?; + "input" => Tensor::from_string_array(session.allocator(), array)? +])?; ``` ## Session outputs @@ -173,7 +154,7 @@ let l = outputs["latents"].try_extract_tensor::()?; ``` ## Execution providers -Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.8/ort/index.html?search=ExecutionProvider) and the [execution providers reference](/perf/execution-providers) for more information. +Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.9/ort/execution_providers/index.html#reexports) and the [execution providers reference](/perf/execution-providers) for more information. ```diff -// v1.x @@ -190,8 +171,11 @@ Execution provider structs with public fields have been replaced with builder pa ## Updated dependencies & features +### `ndarray` 0.16 +The `ndarray` dependency has been upgraded to 0.16. In order to convert tensors from `ndarray`, your application must update to `ndarray` 0.16 as well. + ### `ndarray` is now optional -The dependency on `ndarray` is now declared optional. If you use `ort` with `default-features = false`, you'll need to add the `ndarray` feature. +The dependency on `ndarray` is now optional. If you previously used `ort` with `default-features = false`, you'll need to add the `ndarray` feature to keep using `ndarray` integration. ## Model Zoo structs have been removed ONNX pushed a new Model Zoo structure that adds hundreds of different models. This is impractical to maintain, so the built-in structs have been removed. diff --git a/docs/pages/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx index cb11af0..4d8cd06 100644 --- a/docs/pages/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> { ``` ## Configuring EPs -EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.8/ort/index.html?search=ExecutionProvider) for the EP structs for more information on which options are supported and what they do. +EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.9/ort/execution_providers/index.html#reexports) for the EP structs for more information on which options are supported and what they do. ```rust use ort::{execution_providers::CoreMLExecutionProvider, session::Session}; diff --git a/docs/pages/perf/io-binding.mdx b/docs/pages/perf/io-binding.mdx index b5eeefb..0de5df2 100644 --- a/docs/pages/perf/io-binding.mdx +++ b/docs/pages/perf/io-binding.mdx @@ -12,7 +12,7 @@ In some cases, this I/O overhead is unavoidable -- a causal language model, for For these cases, ONNX Runtime provides **I/O binding**, an interface that allows you to manually specify which inputs/outputs reside on which device, and control when they are synchronized. ## Creating -I/O binding is used via the [`IoBinding`](https://docs.rs/ort/2.0.0-rc.8/ort/struct.IoBinding.html) struct. `IoBinding` is created using the [`Session::create_binding`](https://docs.rs/ort/2.0.0-rc.8/ort/struct.Session.html#method.create_binding) method: +I/O binding is used via the [`IoBinding`](https://docs.rs/ort/2.0.0-rc.9/ort/io_binding/struct.IoBinding.html) struct. `IoBinding` is created using the [`Session::create_binding`](https://docs.rs/ort/2.0.0-rc.9/ort/session/struct.Session.html#method.create_binding) method: ```rs let mut binding = session.create_binding()?; @@ -57,7 +57,7 @@ binding.bind_output_to_device("action", &allocator.memory_info())?; This means that subsequent runs will *override* the data in `action`. If you need to access a bound output's data *across* runs (i.e. in a multithreading setting), the data needs to be copied to another buffer to avoid undefined behavior. -Outputs can be bound to any device -- they can even stay on the EP device if you bind it to a tensor created with the session's allocator (`Tensor::new(session.allocator(), ...)`). You can then access the pointer to device memory using [`Tensor::data_ptr`](https://docs.rs/ort/2.0.0-rc.8/ort/type.Tensor.html#method.data_ptr). +Outputs can be bound to any device -- they can even stay on the EP device if you bind it to a tensor created with the session's allocator (`Tensor::new(session.allocator(), ...)`). You can then access the pointer to device memory using [`Tensor::data_ptr`](https://docs.rs/ort/2.0.0-rc.9/ort/value/type.Tensor.html#method.data_ptr). If you do bind an output to the session's device, it is not guaranteed to be synchronized after `run`, just like `bind_input`. You can force outputs to synchronize immediately using `IoBinding::synchronize_outputs`. diff --git a/docs/pages/setup/cargo-features.mdx b/docs/pages/setup/cargo-features.mdx index 0433de3..1ef2d07 100644 --- a/docs/pages/setup/cargo-features.mdx +++ b/docs/pages/setup/cargo-features.mdx @@ -9,7 +9,7 @@ title: Cargo features - ✅ **`half`**: Enables support for float16 & bfloat16 tensors via the [`half`](https://crates.io/crates/half) crate. ONNX models that are converted to 16-bit precision will typically convert to/from 32-bit floats at the input/output, so you will likely never actually need to interact with a 16-bit tensor on the Rust side. Though, `half` isn't a heavy enough crate to worry about it affecting compile times. - ✅ **`copy-dylibs`**: In case dynamic libraries are used (like with the CUDA execution provider), creates a symlink to them in the relevant places in the `target` folder to make [compile-time dynamic linking](/setup/linking#compile-time-dynamic-linking) work. - ⚒️ **`load-dynamic`**: Enables [runtime dynamic linking](/setup/linking#runtime-loading-with-load-dynamic), which alleviates many of the troubles with compile-time dynamic linking and offers greater flexibility. -- ⚒️ **`fetch-models`**: Enables the [`SessionBuilder::commit_from_url`](https://docs.rs/ort/2.0.0-rc.8/ort/struct.SessionBuilder.html#method.commit_from_url) method, allowing you to quickly download & run a model from a URL. This should only be used for quick testing. +- ⚒️ **`fetch-models`**: Enables the [`SessionBuilder::commit_from_url`](https://docs.rs/ort/2.0.0-rc.9/ort/session/builder/struct.SessionBuilder.html#method.commit_from_url) method, allowing you to quickly download & run a model from a URL. This should only be used for quick testing. ## Execution providers Each [execution provider](/perf/execution-providers) is also gated behind a Cargo feature. diff --git a/src/session/input.rs b/src/session/input.rs index 4ad1bfe..b308c0b 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -76,11 +76,6 @@ impl<'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<'_, /// Construct the inputs to a session from an array or named map of values. /// -/// See [`Value::from_array`] for details on what types a tensor can be created from. -/// -/// Note that the output of this macro is a `Result`, so make sure to handle any potential -/// errors. -/// /// # Example /// /// ## Array of values @@ -110,8 +105,6 @@ impl<'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<'_, /// # Ok(()) /// # } /// ``` -/// -/// [`Tensor::from_string_array`]: crate::value::Tensor::from_string_array #[macro_export] macro_rules! inputs { ($($v:expr),+ $(,)?) => ( diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index 3595eab..0d09914 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -23,15 +23,13 @@ use crate::{ impl Tensor { /// Construct a [`Tensor`] from an array of strings. /// - /// Just like numeric tensors, string tensors can be created from: - /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); - /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); - /// - (with feature `ndarray`) an owned [`ndarray::Array`]; - /// - (with feature `ndarray`) a borrowed view of another array, as an [`ndarray::ArrayView`] (`ArrayView<'_, T, - /// D>`); - /// - a tuple of `(dimensions, data)` where: - /// * `dimensions` is one of `Vec`, `[I]` or `&[I]`, where `I` is `i64` or `usize`; - /// * and `data` is one of `Vec`, `Box<[T]>`, `Arc>`, or `&[T]`. + /// String tensors can be created from: + /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`) or + /// [`ndarray::Array`] (`&Array`); + /// - (with feature `ndarray`) an [`ndarray::ArcArray`] or [`ndarray::ArrayView`]; + /// - a tuple of `(shape, data)` where: + /// * `shape` is one of `Vec`, `[I; N]` or `&[I]`, where `I` is `i64` or `usize`, and + /// * `data` is one of `&[T]`, `Arc<[T]>`, or `Arc>`. /// /// ``` /// # use ort::{session::Session, value::Tensor}; @@ -46,8 +44,6 @@ impl Tensor { /// # Ok(()) /// # } /// ``` - /// - /// Note that string data will *always* be copied, no matter what form the data is provided in. pub fn from_string_array(input: impl TensorArrayData) -> Result> { let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); @@ -155,17 +151,13 @@ impl Tensor { }) } - /// Construct a tensor from an array of data. + /// Construct an owned tensor from an array of data. /// - /// Tensors can be created from: - /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); - /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); - /// - (with feature `ndarray`) an owned [`ndarray::Array`]; - /// - (with feature `ndarray`) a borrowed view of another array, as an [`ndarray::ArrayView`] (`ArrayView<'_, T, - /// D>`); - /// - a tuple of `(dimensions, data)` where: - /// * `dimensions` is one of `Vec`, `[I]` or `&[I]`, where `I` is `i64` or `usize`; - /// * and `data` is one of `Vec`, `Box<[T]>`, `Arc>`, or `&[T]`. + /// Owned tensors can be created from: + /// - (with feature `ndarray`) an owned [`ndarray::Array`], or + /// - a tuple of `(shape, data)` where: + /// * `shape` is one of `Vec`, `[I]` or `&[I]`, where `I` is `i64` or `usize`, and + /// * `data` is one of `Vec` or `Box<[T]>`. /// /// ``` /// # use ort::value::Tensor; @@ -180,16 +172,11 @@ impl Tensor { /// # } /// ``` /// - /// Creating string tensors requires a separate method; see [`DynTensor::from_string_array`]. - /// - /// Note that data provided in an `ndarray` may be copied in some circumstances: - /// - `&CowArray<'_, T, D>` will always be copied regardless of whether it is uniquely owned or borrowed. - /// - `&mut ArcArray` and `Array` will be copied only if the data is not in a contiguous layout (which - /// is the case after most reshape operations) - /// - `ArrayView<'_, T, D>` will always be copied. + /// When passing an [`ndarray::Array`], the array may be copied in order to convert it to a contiguous layout if it + /// is not already. When creating a tensor from a `Vec` or boxed slice, the data is assumed to already be in + /// contiguous layout. /// - /// Raw data provided as a `Arc>`, `Box<[T]>`, or `Vec` will never be copied. Raw data is expected to be - /// in standard, contigous layout. + /// Creating string tensors requires a separate method; see [`Tensor::from_string_array`]. pub fn from_array(input: impl OwnedTensorArrayData) -> Result> { let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; @@ -234,6 +221,35 @@ impl Tensor { } impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> { + /// Construct a tensor from borrowed data. + /// + /// Borrowed tensors can be created from: + /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`) or + /// [`ndarray::Array`] (`&Array`); + /// - (with feature `ndarray`) an [`ndarray::ArcArray`] or [`ndarray::ArrayView`]; + /// - a tuple of `(shape, data)` where: + /// * `shape` is one of `Vec`, `[I; N]` or `&[I]`, where `I` is `i64` or `usize`, and + /// * `data` is one of `&[T]`, `Arc<[T]>`, or `Arc>`. + /// + /// ``` + /// # use ort::value::TensorRef; + /// # fn main() -> ort::Result<()> { + /// // Create a tensor from a raw data vector + /// let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + /// let tensor = TensorRef::from_array_view(([1usize, 2, 3], &*data))?; + /// + /// // Create a tensor from an `ndarray::Array` + /// # #[cfg(feature = "ndarray")] + /// # { + /// let array = ndarray::Array4::::zeros((1, 16, 16, 3)); + /// let tensor = TensorRef::from_array_view(array.view())?; + /// # } + /// # Ok(()) + /// # } + /// ``` + /// + /// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be + /// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout. pub fn from_array_view(input: impl TensorArrayData + 'a) -> Result> { let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; @@ -281,6 +297,34 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> { } impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { + /// Construct a tensor from mutably borrowed data. Modifying data with [`Value::extract_tensor_mut`] will modify the + /// underlying buffer as well. + /// + /// Mutably borrowed tensors can be created from: + /// - (with feature `ndarray`) an exclusive reference to an [`ndarray::Array`] (`&Array`); + /// - (with feature `ndarray`) an [`ndarray::ArrayViewMut`]; + /// - a tuple of `(shape, &mut [T])`, where `shape` is one of `Vec`, `[I; N]` or `&[I]`, where `I` is `i64` or + /// `usize`. + /// + /// ``` + /// # use ort::value::TensorRefMut; + /// # fn main() -> ort::Result<()> { + /// // Create a tensor from a raw data vector + /// let mut data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + /// let tensor = TensorRefMut::from_array_view_mut(([1usize, 2, 3], &mut *data))?; + /// + /// // Create a tensor from an `ndarray::Array` + /// # #[cfg(feature = "ndarray")] + /// # { + /// let mut array = ndarray::Array4::::zeros((1, 16, 16, 3)); + /// let tensor = TensorRefMut::from_array_view_mut(array.view_mut())?; + /// # } + /// # Ok(()) + /// # } + /// ``` + /// + /// When passing an [`ndarray`] type, the data **must** have a contiguous memory layout, or else an error will be + /// returned. See [`ndarray::ArrayBase::as_standard_layout`] to convert an array to a contiguous layout. pub fn from_array_view_mut(mut input: impl TensorArrayDataMut) -> Result> { let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?; diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index 5d9baa1..b2773a5 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -355,7 +355,7 @@ impl Value { /// # use ort::value::Tensor; /// # fn main() -> ort::Result<()> { /// let array = vec!["hello", "world"]; - /// let tensor = Tensor::from_string_array(([array.len()], array.clone().into_boxed_slice()))?.into_dyn(); + /// let tensor = Tensor::from_string_array(([array.len()], &*array))?.into_dyn(); /// /// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?; /// assert_eq!(extracted_data, array); diff --git a/src/value/type.rs b/src/value/type.rs index 27983d6..5164db0 100644 --- a/src/value/type.rs +++ b/src/value/type.rs @@ -5,7 +5,7 @@ use std::{ use crate::{ortsys, tensor::TensorElementType}; -/// The type of a [`Value`], or a session input/output. +/// The type of a [`Value`][super::Value], or a session input/output. /// /// ``` /// # use std::sync::Arc; @@ -73,7 +73,7 @@ pub enum ValueType { /// The map value type. value: TensorElementType }, - /// An optional value, which may or may not contain a [`Value`]. + /// An optional value, which may or may not contain a [`Value`][super::Value]. Optional(Box) }