Skip to content

Commit

Permalink
feat: expose tensor array conversion traits
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Dec 26, 2024
1 parent 794c041 commit 6778056
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
38 changes: 38 additions & 0 deletions src/value/impl_tensor/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,15 +391,23 @@ impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> {
}

/// A type whose contents can be exposed as the borrowed parts of a tensor:
/// a shape, a contiguous element slice, and an optional ownership guard.
pub trait TensorArrayData<I> {
	/// Returns the tensor shape (dimensions as `i64`), a borrowed slice of the
	/// element data, and an optional guard object.
	///
	/// NOTE(review): the `Box<dyn Any>` guard presumably keeps the backing
	/// storage alive while the tensor borrows it (implementations below return
	/// `Some(Box::new(self.clone()))` for `Arc`-backed data) — confirm against callers.
	#[allow(clippy::type_complexity)]
	fn ref_parts(&self) -> Result<(Vec<i64>, &[I], Option<Box<dyn Any>>)>;

	// Seals the trait so it can only be implemented inside this crate
	// (presumably — `private_trait!` is defined elsewhere; confirm).
	crate::private_trait!();
}

/// A [`TensorArrayData`] whose element data can additionally be borrowed mutably.
pub trait TensorArrayDataMut<I>: TensorArrayData<I> {
	/// Returns the tensor shape (dimensions as `i64`), a mutably borrowed slice
	/// of the element data, and an optional guard object.
	///
	/// NOTE(review): the guard presumably keeps the backing storage alive while
	/// borrowed — confirm against callers.
	#[allow(clippy::type_complexity)]
	fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [I], Option<Box<dyn Any>>)>;

	// Seals the trait to in-crate implementations (presumably — confirm
	// `private_trait!`'s definition).
	crate::private_trait!();
}

/// A type that can be consumed to yield owned tensor parts
/// (see [`TensorArrayDataParts`]).
pub trait OwnedTensorArrayData<I> {
	/// Consumes `self`, returning the shape, a pointer to the element data,
	/// the element count, and a guard that owns the underlying storage.
	fn into_parts(self) -> Result<TensorArrayDataParts<I>>;

	// Seals the trait to in-crate implementations (presumably — confirm
	// `private_trait!`'s definition).
	crate::private_trait!();
}

pub struct TensorArrayDataParts<I> {
Expand Down Expand Up @@ -478,6 +486,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &CowArra
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data, None))
}

crate::private_impl!();
}

#[cfg(feature = "ndarray")]
Expand All @@ -490,6 +500,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArcArray
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data, Some(Box::new(self.clone()))))
}

crate::private_impl!();
}

#[cfg(feature = "ndarray")]
Expand All @@ -502,6 +514,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for &Array<T
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data, None))
}

crate::private_impl!();
}

#[cfg(feature = "ndarray")]
Expand All @@ -525,6 +539,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> OwnedTensorArrayData<T> for Arr
Ok(TensorArrayDataParts { shape, ptr, num_elements, guard })
}
}

crate::private_impl!();
}

#[cfg(feature = "ndarray")]
Expand All @@ -537,6 +553,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayVie
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data, None))
}

crate::private_impl!();
}

#[cfg(feature = "ndarray")]
Expand All @@ -549,6 +567,8 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayData<T> for ArrayVie
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data, None))
}

crate::private_impl!();
}

#[cfg(feature = "ndarray")]
Expand All @@ -561,27 +581,35 @@ impl<T: Clone + 'static, D: Dimension + 'static> TensorArrayDataMut<T> for Array
.ok_or_else(|| Error::new("Array has a non-contiguous layout and cannot be used to construct a Tensor"))?;
Ok((shape, data, None))
}

crate::private_impl!();
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &[T]) {
	/// Computes the tensor shape from the dimension descriptor (validated
	/// against the slice length) and returns it alongside the borrowed
	/// element slice. No ownership guard is needed: the data is already
	/// borrowed for at least as long as `self`.
	fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
		let (dimensions, data) = self;
		let shape = dimensions.to_dimensions(Some(data.len()))?;
		Ok((shape, *data, None))
	}

	crate::private_impl!();
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, &mut [T]) {
	/// Reborrows the mutable slice immutably and pairs it with the shape
	/// computed from the dimension descriptor. No ownership guard is
	/// returned; the borrow is tied to `self`.
	fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
		// Explicit shared reborrow of the `&mut [T]` held in the tuple.
		let data: &[T] = &*self.1;
		let shape = self.0.to_dimensions(Some(data.len()))?;
		Ok((shape, data, None))
	}

	crate::private_impl!();
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayDataMut<T> for (D, &mut [T]) {
	/// Like `ref_parts`, but hands the element slice out mutably. The shape is
	/// still validated against the slice's length before the mutable reborrow.
	fn ref_parts_mut(&mut self) -> Result<(Vec<i64>, &mut [T], Option<Box<dyn Any>>)> {
		let element_count = self.1.len();
		let shape = self.0.to_dimensions(Some(element_count))?;
		Ok((shape, &mut *self.1, None))
	}

	crate::private_impl!();
}

impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Vec<T>) {
Expand All @@ -596,6 +624,8 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Vec<T>
guard: Box::new(self.1)
})
}

crate::private_impl!();
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Box<[T]>) {
Expand All @@ -604,6 +634,8 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Box<[T]>) {
let data = &*self.1;
Ok((shape, data, None))
}

crate::private_impl!();
}

impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T]>) {
Expand All @@ -618,6 +650,8 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T
guard: Box::new(self.1)
})
}

crate::private_impl!();
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<[T]>) {
Expand All @@ -626,6 +660,8 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<[T]>) {
let data = &*self.1;
Ok((shape, data, Some(Box::new(self.1.clone()))))
}

crate::private_impl!();
}

impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<Box<[T]>>) {
Expand All @@ -634,4 +670,6 @@ impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<Box<[T]
let data = &*self.1;
Ok((shape, data, Some(Box::new(self.1.clone()))))
}

crate::private_impl!();
}
1 change: 1 addition & 0 deletions src/value/impl_tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{
sync::Arc
};

pub use self::create::{OwnedTensorArrayData, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts};
use super::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
use crate::{AsPointer, error::Result, memory::MemoryInfo, ortsys, tensor::IntoTensorElementType};

Expand Down
5 changes: 4 additions & 1 deletion src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ pub use self::{
impl_sequence::{
DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker
},
impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker},
impl_tensor::{
DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, OwnedTensorArrayData, Tensor, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts,
TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker
},
r#type::ValueType
};
use crate::{
Expand Down

0 comments on commit 6778056

Please sign in to comment.