refactor!: precompute value dtype/memory info
Breaking because `extract_tensor_*` now returns `&[i64]` for dimensions, and `dtype()` and `memory_info()` also return references.

Each tensor extract call not only made multiple FFI calls to determine the `ValueType`, but also had to determine the `MemoryInfo` to ensure the data was CPU-accessible. Since neither the data type nor the memory location can *change* for a given value, it doesn't make sense to recompute them on every extract call; it's better to compute them once, when we create the `Value` (and we often already have the types constructed by that point, so little FFI is actually required).

This should make `extract_tensor_raw` zero-alloc, most benefiting usages of `IoBinding`/`OutputSelector`. It does mean that usages of `Value` which never extract the value (like HF Transformers hidden-state outputs that go ignored) incur slightly more overhead, but lower overhead at extraction time seems worth the tradeoff.
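As a rough sketch, the new call pattern looks like this (the `output` value and the session setup around it are assumed for illustration, not part of this commit):

// The returned shape is now a borrowed &[i64] rather than an owned Vec<i64>,
// and no per-call FFI is needed to look up the dtype or memory location.
let (shape, data) = output.try_extract_raw_tensor::<f32>()?;
let element_count = shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize;
assert_eq!(data.len(), element_count);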
decahedron1 committed Nov 16, 2024
1 parent e34092b commit 1dbad54
Showing 15 changed files with 447 additions and 460 deletions.
8 changes: 4 additions & 4 deletions examples/custom-ops/examples/custom-ops.rs
@@ -39,9 +39,9 @@ impl Kernel for CustomOpOneKernel {
 		let (x_shape, x) = x.try_extract_raw_tensor::<f32>()?;
 		let (y_shape, y) = y.try_extract_raw_tensor::<f32>()?;
 
-		let mut z = ctx.output(0, x_shape)?.unwrap();
+		let mut z = ctx.output(0, x_shape.to_vec())?.unwrap();
 		let (_, z_ref) = z.try_extract_raw_tensor_mut::<f32>()?;
-		for i in 0..y_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize {
+		for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize {
 			if i % 2 == 0 {
 				z_ref[i] = x[i];
 			} else {
@@ -79,9 +79,9 @@ impl Kernel for CustomOpTwoKernel {
 	fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> {
 		let x = ctx.input(0)?.unwrap();
 		let (x_shape, x) = x.try_extract_raw_tensor::<f32>()?;
-		let mut z = ctx.output(0, x_shape.clone())?.unwrap();
+		let mut z = ctx.output(0, x_shape.to_vec())?.unwrap();
 		let (_, z_ref) = z.try_extract_raw_tensor_mut::<i32>()?;
-		for i in 0..x_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize {
+		for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap() as usize {
 			z_ref[i] = (x[i] * i as f32) as i32;
 		}
 		Ok(())
44 changes: 3 additions & 41 deletions examples/model-info/examples/model-info.rs
@@ -1,44 +1,6 @@
 use std::{env, process};
 
-use ort::{session::Session, tensor::TensorElementType, value::ValueType};
-
-fn display_element_type(t: TensorElementType) -> &'static str {
-	match t {
-		TensorElementType::Bfloat16 => "bf16",
-		TensorElementType::Bool => "bool",
-		TensorElementType::Float16 => "f16",
-		TensorElementType::Float32 => "f32",
-		TensorElementType::Float64 => "f64",
-		TensorElementType::Int16 => "i16",
-		TensorElementType::Int32 => "i32",
-		TensorElementType::Int64 => "i64",
-		TensorElementType::Int8 => "i8",
-		TensorElementType::String => "str",
-		TensorElementType::Uint16 => "u16",
-		TensorElementType::Uint32 => "u32",
-		TensorElementType::Uint64 => "u64",
-		TensorElementType::Uint8 => "u8"
-	}
-}
-
-fn display_value_type(value: &ValueType) -> String {
-	match value {
-		ValueType::Tensor { ty, dimensions } => {
-			format!(
-				"Tensor<{}>({})",
-				display_element_type(*ty),
-				dimensions
-					.iter()
-					.map(|c| if *c == -1 { "dyn".to_string() } else { c.to_string() })
-					.collect::<Vec<_>>()
-					.join(", ")
-			)
-		}
-		ValueType::Map { key, value } => format!("Map<{}, {}>", display_element_type(*key), display_element_type(*value)),
-		ValueType::Sequence(inner) => format!("Sequence<{}>", display_value_type(inner)),
-		ValueType::Optional(inner) => format!("Option<{}>", display_value_type(inner))
-	}
-}
+use ort::session::Session;
 
 fn main() -> ort::Result<()> {
 	let Some(path) = env::args().nth(1) else {
@@ -61,11 +23,11 @@ fn main() -> ort::Result<()> {
 
 	println!("Inputs:");
 	for (i, input) in session.inputs.iter().enumerate() {
-		println!(" {i} {}: {}", input.name, display_value_type(&input.input_type));
+		println!(" {i} {}: {}", input.name, input.input_type);
 	}
 	println!("Outputs:");
 	for (i, output) in session.outputs.iter().enumerate() {
-		println!(" {i} {}: {}", output.name, display_value_type(&output.output_type));
+		println!(" {i} {}: {}", output.name, output.output_type);
 	}
 
 	Ok(())
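The deleted helpers are superseded by a `Display` implementation on `ValueType` (added elsewhere in this commit, not shown in this excerpt), so `input.input_type` and `output.output_type` can be formatted directly. By analogy with the removed `display_value_type`, a tensor input would presumably print along these lines:

println!("{}", input.input_type); // e.g. Tensor<f32>(dyn, 3, 224, 224) — exact format assumed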
10 changes: 3 additions & 7 deletions src/io_binding.rs
@@ -4,7 +4,6 @@ use std::{
 	collections::HashMap,
 	ffi::CString,
 	fmt::Debug,
-	marker::PhantomData,
 	ptr::{self, NonNull},
 	sync::Arc
 };
@@ -214,7 +213,7 @@ impl IoBinding {
 		let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { std::ptr::null() };
 		ortsys![unsafe RunWithBinding(self.session.ptr().cast_mut(), run_options_ptr, self.ptr())?];
 
-		let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc<ValueInner>> = self.output_values.values().map(|c| (c.ptr().cast_mut(), &c.inner)).collect();
+		let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Value> = self.output_values.values().map(|c| (c.ptr().cast_mut(), c)).collect();
 		let mut count = self.output_names.len();
 		if count > 0 {
 			let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut();
@@ -223,11 +222,8 @@
 			let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count).to_vec() }
 				.into_iter()
 				.map(|v| unsafe {
-					if let Some(inner) = owned_ptrs.get(&v) {
-						DynValue {
-							inner: Arc::clone(*inner),
-							_markers: PhantomData
-						}
+					if let Some(value) = owned_ptrs.get(&v) {
+						DynValue::clone_of(value)
 					} else {
 						DynValue::from_ptr(
 							NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
13 changes: 13 additions & 0 deletions src/memory.rs
@@ -400,6 +400,19 @@ impl MemoryInfo {
 		})
 	}
 
+	pub(crate) fn from_value(value_ptr: *mut ort_sys::OrtValue) -> Option<Self> {
+		let mut is_tensor = 0;
+		ortsys![unsafe IsTensor(value_ptr, &mut is_tensor)]; // infallible
+		if is_tensor != 0 {
+			let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut();
+			// infallible, and `memory_info_ptr` will never be null
+			ortsys![unsafe GetTensorMemoryInfo(value_ptr, &mut memory_info_ptr)];
+			Some(Self::from_raw(unsafe { NonNull::new_unchecked(memory_info_ptr.cast_mut()) }, false))
+		} else {
+			None
+		}
+	}
+
 	pub(crate) fn from_raw(ptr: NonNull<ort_sys::OrtMemoryInfo>, should_release: bool) -> Self {
 		MemoryInfo { ptr, should_release }
 	}
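Note the `IsTensor` gate: only tensor values carry an `OrtMemoryInfo`, so maps and sequences precompute `None`. A minimal sketch of the call site, as seen at the `Value` construction sites later in this diff:

// Tensors get Some(MemoryInfo); non-tensor values get None.
// Passing `should_release: false` to `from_raw` marks the OrtMemoryInfo as
// borrowed from the OrtValue, so dropping the MemoryInfo won't release it.
let memory_info: Option<MemoryInfo> = MemoryInfo::from_value(value_ptr);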
8 changes: 7 additions & 1 deletion src/operator/kernel.rs
@@ -7,7 +7,7 @@ use std::{
 use crate::{
 	AsPointer,
 	error::{Error, Result, status_to_result},
-	memory::{Allocator, MemoryInfo},
+	memory::{Allocator, MemoryInfo, MemoryType},
 	ortsys,
 	session::{Input, Output},
 	value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType}
@@ -89,6 +89,12 @@ impl KernelAttributes {
 		ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
 		CString::from_vec_with_nul(name).map_err(Error::wrap)?.into_string().map_err(Error::wrap)
 	}
+
+	pub fn allocator(&self, mem_type: MemoryType) -> Result<Allocator> {
+		let mut ptr: *mut ort_sys::OrtAllocator = ptr::null_mut();
+		ortsys![unsafe KernelInfoGetAllocator(self.0.as_ptr(), mem_type.into(), &mut ptr)?];
+		Ok(unsafe { Allocator::from_raw_unchecked(ptr) })
+	}
 }
 
 impl AsPointer for KernelAttributes {
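A hedged usage sketch of the new accessor from inside a custom operator (the `attrs: KernelAttributes` binding and the `MemoryType::Default` variant are assumptions for illustration):

// Fetch the allocator ONNX Runtime associates with this kernel,
// e.g. to allocate scratch buffers on the kernel's device:
let allocator: Allocator = attrs.allocator(MemoryType::Default)?;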
8 changes: 4 additions & 4 deletions src/operator/tests.rs
@@ -41,9 +41,9 @@ impl Kernel for CustomOpOneKernel {
 		let (x_shape, x) = x.try_extract_raw_tensor::<f32>()?;
 		let (y_shape, y) = y.try_extract_raw_tensor::<f32>()?;
 
-		let mut z = ctx.output(0, x_shape)?.ok_or_else(|| crate::Error::new("missing input"))?;
+		let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?;
 		let (_, z_ref) = z.try_extract_raw_tensor_mut::<f32>()?;
-		for i in 0..y_shape.into_iter().reduce(|acc, e| acc * e).unwrap_or(0) as usize {
+		for i in 0..y_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize {
 			if i % 2 == 0 {
 				z_ref[i] = x[i];
 			} else {
@@ -81,9 +81,9 @@ impl Kernel for CustomOpTwoKernel {
 	fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()> {
 		let x = ctx.input(0)?.ok_or_else(|| crate::Error::new("missing input"))?;
 		let (x_shape, x) = x.try_extract_raw_tensor::<f32>()?;
-		let mut z = ctx.output(0, x_shape.clone())?.ok_or_else(|| crate::Error::new("missing input"))?;
+		let mut z = ctx.output(0, x_shape.to_vec())?.ok_or_else(|| crate::Error::new("missing input"))?;
 		let (_, z_ref) = z.try_extract_raw_tensor_mut::<i32>()?;
-		for i in 0..x_shape.into_iter().reduce(|acc, e| acc * e).unwrap_or(0) as usize {
+		for i in 0..x_shape.iter().copied().reduce(|acc, e| acc * e).unwrap_or(0) as usize {
 			z_ref[i] = (x[i] * i as f32) as i32;
 		}
 		Ok(())
9 changes: 2 additions & 7 deletions src/session/output.rs
@@ -1,11 +1,9 @@
 use std::{
 	ffi::c_void,
 	iter::FusedIterator,
-	marker::PhantomData,
 	mem::ManuallyDrop,
 	ops::{Index, IndexMut},
-	ptr,
-	sync::Arc
+	ptr
 };
 
 use crate::{
@@ -113,10 +111,7 @@ impl<'r, 's> SessionOutputs<'r, 's> {
 			if &key == k {
 				*k = "";
 				self.effective_len -= 1;
-				return Some(DynValue {
-					inner: Arc::clone(&self.values[i].inner),
-					_markers: PhantomData
-				});
+				return Some(DynValue::clone_of(&self.values[i]));
 			}
 		}
 		None
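`DynValue::clone_of` is introduced by this commit (its definition falls outside this excerpt); judging by the struct literals it replaces here and in `io_binding.rs`, it presumably centralizes the same cheap `Arc` clone, along these lines:

// Hypothetical sketch of the helper, not the actual definition:
pub(crate) fn clone_of(value: &DynValue) -> DynValue {
	DynValue {
		inner: Arc::clone(&value.inner),
		_markers: PhantomData
	}
}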
10 changes: 1 addition & 9 deletions src/session/run_options.rs
@@ -116,15 +116,7 @@ impl OutputSelector {
 			.map(|o| &o.name)
 			.filter(|n| !self.default_blocklist.contains(n))
 			.chain(self.allowlist.iter())
-			.map(|n| {
-				(
-					n.as_str(),
-					self.preallocated_outputs.get(n).map(|v| DynValue {
-						inner: Arc::clone(&v.inner),
-						_markers: PhantomData
-					})
-				)
-			})
+			.map(|n| (n.as_str(), self.preallocated_outputs.get(n).map(DynValue::clone_of)))
 			.unzip()
 	}
 }
21 changes: 13 additions & 8 deletions src/value/impl_map.rs
@@ -81,11 +81,11 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
 		match self.dtype() {
 			ValueType::Map { key, value } => {
 				let k_type = K::into_tensor_element_type();
-				if k_type != key {
+				if k_type != *key {
 					return Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot extract Map<{:?}, _> (value has K type {:?})", k_type, key)));
 				}
 				let v_type = V::into_tensor_element_type();
-				if v_type != value {
+				if v_type != *value {
 					return Err(Error::new_with_code(
 						ErrorCode::InvalidArgument,
 						format!("Cannot extract Map<{}, {}> from Map<{}, {}>", K::into_tensor_element_type(), V::into_tensor_element_type(), k_type, v_type)
@@ -100,7 +100,7 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
 				if K::into_tensor_element_type() != TensorElementType::String {
 					let dtype = key_value.dtype();
 					let (key_tensor_shape, key_tensor) = match dtype {
-						ValueType::Tensor { ty, dimensions } => {
+						ValueType::Tensor { ty, dimensions, .. } => {
 							let mem = key_value.memory_info();
 							if !mem.is_cpu_accessible() {
 								return Err(Error::new(format!(
@@ -109,13 +109,13 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
 								)));
 							}
 
-							if ty == K::into_tensor_element_type() {
+							if *ty == K::into_tensor_element_type() {
 								let mut output_array_ptr: *mut K = ptr::null_mut();
 								let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr;
 								let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast();
 								ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void)?; nonNull(output_array_ptr)];
 
-								let len = calculate_tensor_size(&dimensions);
+								let len = calculate_tensor_size(dimensions);
 								(dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })
 							} else {
 								return Err(Error::new_with_code(
@@ -251,10 +251,15 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq + 'static, V: IntoTens
 			nonNull(value_ptr)
 		];
 		Ok(Value {
-			inner: Arc::new(ValueInner::RustOwned {
+			inner: Arc::new(ValueInner {
 				ptr: unsafe { NonNull::new_unchecked(value_ptr) },
-				_array: Box::new(values),
-				_memory_info: None
+				dtype: ValueType::Map {
+					key: K::into_tensor_element_type(),
+					value: V::into_tensor_element_type()
+				},
+				drop: true,
+				memory_info: None,
+				_backing: Some(Box::new(values))
 			}),
 			_markers: PhantomData
 		})
12 changes: 8 additions & 4 deletions src/value/impl_sequence.rs
@@ -87,7 +87,7 @@ impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
 
 		let value = unsafe { Value::from_ptr(NonNull::new_unchecked(value_ptr), None) };
 		let value_type = value.dtype();
-		if !OtherType::can_downcast(&value.dtype()) {
+		if !OtherType::can_downcast(value.dtype()) {
 			return Err(Error::new_with_code(
 				ErrorCode::InvalidArgument,
 				format!("Cannot extract Sequence<{}> from {value_type:?}", OtherType::format())
@@ -134,10 +134,14 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized + 'static> Value<Se
 			nonNull(value_ptr)
 		];
 		Ok(Value {
-			inner: Arc::new(ValueInner::RustOwned {
+			inner: Arc::new(ValueInner {
 				ptr: unsafe { NonNull::new_unchecked(value_ptr) },
-				_array: Box::new(values),
-				_memory_info: None
+				// 1. `CreateValue` enforces that we have at least 1 value
+				// 2. `CreateValue` internally uses the first value to determine the element type, so we do the same here
+				dtype: ValueType::Sequence(Box::new(values[0].inner.dtype.clone())),
+				drop: true,
+				memory_info: None,
+				_backing: Some(Box::new(values))
 			}),
 			_markers: PhantomData
 		})
48 changes: 36 additions & 12 deletions src/value/impl_tensor/create.rs
@@ -17,7 +17,7 @@ use crate::{
 	memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType},
 	ortsys,
 	tensor::{PrimitiveTensorElementType, TensorElementType, Utf8Data},
-	value::{DynValue, Value, ValueInner}
+	value::{DynValue, Value, ValueInner, ValueType}
 };
 
 impl Tensor<String> {
@@ -76,10 +76,16 @@ impl Tensor<String> {
 		ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len())?];
 
 		Ok(Value {
-			inner: Arc::new(ValueInner::RustOwned {
+			inner: Arc::new(ValueInner {
 				ptr: unsafe { NonNull::new_unchecked(value_ptr) },
-				_array: Box::new(()),
-				_memory_info: None
+				dtype: ValueType::Tensor {
+					ty: TensorElementType::String,
+					dimensions: shape,
+					dimension_symbols: vec![None; shape_len]
+				},
+				memory_info: MemoryInfo::from_value(value_ptr),
+				drop: true,
+				_backing: None
 			}),
 			_markers: PhantomData
 		})
@@ -124,10 +130,16 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
 		];
 
 		Ok(Value {
-			inner: Arc::new(ValueInner::RustOwned {
+			inner: Arc::new(ValueInner {
 				ptr: unsafe { NonNull::new_unchecked(value_ptr) },
-				_array: Box::new(()),
-				_memory_info: None
+				dtype: ValueType::Tensor {
+					ty: T::into_tensor_element_type(),
+					dimensions: shape,
+					dimension_symbols: vec![None; shape_len]
+				},
+				drop: true,
+				memory_info: MemoryInfo::from_value(value_ptr),
+				_backing: None
 			}),
 			_markers: PhantomData
 		})
@@ -195,10 +207,16 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
 		];
 
 		Ok(Value {
-			inner: Arc::new(ValueInner::RustOwned {
+			inner: Arc::new(ValueInner {
 				ptr: unsafe { NonNull::new_unchecked(value_ptr) },
-				_array: guard,
-				_memory_info: Some(memory_info)
+				dtype: ValueType::Tensor {
+					ty: T::into_tensor_element_type(),
+					dimensions: shape,
+					dimension_symbols: vec![None; shape_len]
+				},
+				drop: true,
+				memory_info: Some(memory_info),
+				_backing: Some(guard)
 			}),
 			_markers: PhantomData
 		})
@@ -252,10 +270,16 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
 		];
 
 		Ok(TensorRefMut::new(Value {
-			inner: Arc::new(ValueInner::CppOwned {
+			inner: Arc::new(ValueInner {
 				ptr: unsafe { NonNull::new_unchecked(value_ptr) },
+				dtype: ValueType::Tensor {
+					ty: T::into_tensor_element_type(),
+					dimensions: shape,
+					dimension_symbols: vec![None; shape_len]
+				},
 				drop: true,
-				_session: None
+				memory_info: Some(info),
+				_backing: None
 			}),
 			_markers: PhantomData
 		}))
(Diffs for the remaining 4 changed files were not loaded.)
