Skip to content

Commit

Permalink
Make mapv_into_any() work for ArcArray, resolves rust-ndarray#1280
Browse files Browse the repository at this point in the history
  • Loading branch information
benkay86 committed Oct 24, 2024
1 parent 492b274 commit ac9e3da
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 12 deletions.
42 changes: 32 additions & 10 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use std::mem::{size_of, ManuallyDrop};
use crate::imp_prelude::*;

use crate::argument_traits::AssignElem;
use crate::data_traits::RawDataSubst;
use crate::dimension;
use crate::dimension::broadcast::co_broadcast;
use crate::dimension::reshape_dim;
Expand Down Expand Up @@ -2814,15 +2815,29 @@ where
/// map is performed as in [`mapv`].
///
/// Elements are visited in arbitrary order.
///
///
/// Note that the compiler will need some hint about the return type, which
/// is generic over [`DataOwned`], and can thus be an [`Array`] or
/// [`ArcArray`]. Example:
///
/// ```rust
/// # use ndarray::{array, Array};
/// let a = array![[1., 2., 3.]];
/// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.);
/// ```
///
/// [`mapv_into`]: ArrayBase::mapv_into
/// [`mapv`]: ArrayBase::mapv
pub fn mapv_into_any<B, F>(self, mut f: F) -> Array<B, D>
pub fn mapv_into_any<B, F, T>(self, mut f: F) -> ArrayBase<T, D>
where
S: DataMut,
S: DataMut<Elem = A>,
F: FnMut(A) -> B,
A: Clone + 'static,
B: 'static,
T: DataOwned<Elem = B> + RawDataSubst<A> + 'static, // lets us introspect on the types of array representations containing different data elements
<T as RawDataSubst<A>>::Output: RawData, // required by mapv_into()
ArrayBase<<T as RawDataSubst<A>>::Output, D>: From<ArrayBase<S, D>>, // required by into() to convert from the DataMut array representation of S to the DataOwned array representation of T
ArrayBase<T, D>: From<Array<B, D>>, // required by mapv()
{
if core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>() {
// A and B are the same type.
Expand All @@ -2832,16 +2847,23 @@ where
// Safe because A and B are the same type.
unsafe { unlimited_transmute::<B, A>(b) }
};
// Delegate to mapv_into() using the wrapped closure.
// Convert output to a uniquely owned array of type Array<A, D>.
let output = self.mapv_into(f).into_owned();
// Change the return type from Array<A, D> to Array<B, D>.
// Again, safe because A and B are the same type.
unsafe { unlimited_transmute::<Array<A, D>, Array<B, D>>(output) }
// Delegate to mapv_into() to map from element type A to type A.
let output = self.mapv_into(f);
// Convert from S's data storage to T's data storage.
// Suppose `T is `OwnedRepr<B>`.
// Then `<T as RawDataSubst<A>>::Output` is `OwnedRepr<A>`.
let output: ArrayBase<<T as RawDataSubst<A>>::Output, D> = output.into();
// Since A == B and T stores elements of type B, it should be true
// that <T as RawDataSubst<A>>::Output == T.
// Verify that this is indeed the case.
assert!(core::any::TypeId::of::<<T as RawDataSubst<A>>::Output>() == core::any::TypeId::of::<T>());
// Now we can safely transmute the element type from A to the
// identical type B, keeping the same data storage.
unsafe { unlimited_transmute::<ArrayBase<<T as RawDataSubst<A>>::Output, D>, ArrayBase<T,D>>(output) }
} else {
// A and B are not the same type.
// Fallback to mapv().
self.mapv(f)
self.mapv(f).into()
}
}

Expand Down
30 changes: 28 additions & 2 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1054,15 +1054,41 @@ fn mapv_into_any_same_type()
{
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one);
let b: Array<_, _> = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_diff_types()
{
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even);
let b: Array<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
fn mapv_into_any_arcarray_same_type() {
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_arcarray_diff_types() {
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
fn mapv_into_any_diff_outer_types() {
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
Expand Down

0 comments on commit ac9e3da

Please sign in to comment.