Skip to content

Commit

Permalink
fix: TensorRef::from_array_view lifetime
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Dec 21, 2024
1 parent 9ea18d8 commit c8bd1cc
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/value/impl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl<V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType
/// ```
pub fn new(data: impl IntoIterator<Item = (String, V)>) -> Result<Self> {
let (keys, values): (Vec<String>, Vec<V>) = data.into_iter().unzip();
Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?)
Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys.as_slice()))?, Tensor::from_array((vec![values.len()], values))?)
}
}

Expand Down
85 changes: 32 additions & 53 deletions src/value/impl_tensor/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ use std::{
sync::Arc
};

use ndarray::ArrayViewMut;
#[cfg(feature = "ndarray")]
use ndarray::{ArcArray, Array, ArrayView, CowArray, Dimension};
use ndarray::{ArcArray, Array, ArrayView, ArrayViewMut, CowArray, Dimension};

use super::{Tensor, TensorRef, TensorRefMut, calculate_tensor_size};
use crate::{
Expand Down Expand Up @@ -52,7 +51,7 @@ impl Tensor<String> {
pub fn from_string_array<T: Utf8Data>(input: impl TensorArrayData<T>) -> Result<Tensor<String>> {
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();

let (shape, data) = input.ref_parts()?;
let (shape, data, _guard) = input.ref_parts()?;
let shape_ptr: *const i64 = shape.as_ptr();
let shape_len = shape.len();

Expand Down Expand Up @@ -225,13 +224,13 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
}

impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> {
pub fn from_array_view(input: impl TensorArrayData<T>) -> Result<TensorRef<'a, T>> {
pub fn from_array_view(input: impl TensorArrayData<T> + 'a) -> Result<TensorRef<'a, T>> {
let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?;

let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();

// f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime
let (shape, data) = input.ref_parts()?;
let (shape, data, guard) = input.ref_parts()?;
let num_elements = calculate_tensor_size(&shape);
let shape_ptr: *const i64 = shape.as_ptr();
let shape_len = shape.len();
Expand Down Expand Up @@ -262,7 +261,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRef<'a, T> {
},
drop: true,
memory_info: Some(memory_info),
_backing: None
_backing: guard
}),
_markers: PhantomData
});
Expand All @@ -278,7 +277,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();

// f16 and bf16 are repr(transparent) to u16, so memory layout should be identical to onnxruntime
let (shape, data) = input.ref_parts_mut()?;
let (shape, data, guard) = input.ref_parts_mut()?;
let num_elements = calculate_tensor_size(&shape);
let shape_ptr: *const i64 = shape.as_ptr();
let shape_len = shape.len();
Expand Down Expand Up @@ -309,7 +308,7 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
},
drop: true,
memory_info: Some(memory_info),
_backing: None
_backing: guard
}),
_markers: PhantomData
});
Expand Down Expand Up @@ -382,14 +381,14 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
}

pub trait TensorArrayData<I> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[I])>;
fn ref_parts(&self) -> Result<(Vec<i64>, &[I], Option<Box<dyn Any>>)>;
}

pub trait TensorArrayDataMut<I>: TensorArrayData<I> {
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [I])>;
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [I], Option<Box<dyn Any>>)>;
}

pub trait OwnedTensorArrayData<I>: TensorArrayData<I> {
pub trait OwnedTensorArrayData<I> {
fn into_parts(self) -> Result<TensorArrayDataParts<I>>;
}

Expand Down Expand Up @@ -462,48 +461,36 @@ impl_to_dimensions!(<N> for [usize; N], for [i32; N], for [i64; N]);
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &CowArray<'_, T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data))
Ok((shape, data, None))
}
}

#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArcArray<T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data))
}
}

#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for Array<T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data))
Ok((shape, data, Some(Box::new(self.clone()))))
}
}

#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &Array<T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data))
Ok((shape, data, None))
}
}

Expand Down Expand Up @@ -533,65 +520,57 @@ impl<T: Clone + 'static, D: Dimension + 'static> OwnedTensorArrayData<T> for Arr
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayView<'_, T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data))
Ok((shape, data, None))
}
}

#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayViewMut<'_, T, D> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data))
Ok((shape, data, None))
}
}

#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for ArrayViewMut<'_, T, D> {
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T])> {
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
let shape: Vec<i64> = self.shape().iter().map(|d| *d as i64).collect();
let data = self
.as_slice_mut()
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data))
Ok((shape, data, None))
}
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &[T]) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
Ok((shape, self.1))
Ok((shape, self.1, None))
}
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &mut [T]) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
Ok((shape, self.1))
Ok((shape, self.1, None))
}
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayDataMut<T> for (D, &mut [T]) {
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T])> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
Ok((shape, self.1))
}
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Vec<T>) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
let data = &*self.1;
Ok((shape, data))
Ok((shape, self.1, None))
}
}

Expand All @@ -610,10 +589,10 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Vec<T>
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Box<[T]>) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
let data = &*self.1;
Ok((shape, data))
Ok((shape, data, None))
}
}

Expand All @@ -632,9 +611,9 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<Box<[T]>>) {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T])> {
fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
let shape = self.0.to_dimensions(Some(self.1.len()))?;
let data = &*self.1;
Ok((shape, data))
Ok((shape, data, Some(Box::new(self.1.clone()))))
}
}
5 changes: 0 additions & 5 deletions src/value/impl_tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,6 @@ mod tests {
let value = TensorRef::from_array_view(&cow)?;
assert_eq!(value.extract_raw_tensor().1, &v);

let owned = Array1::from_vec(v.clone());
let value = TensorRef::from_array_view(owned.view())?;
drop(owned);
assert_eq!(value.extract_raw_tensor().1, &v);

Ok(())
}

Expand Down

0 comments on commit c8bd1cc

Please sign in to comment.