From 2d6bcdeb31387eddc1294fd9c90b77eae0a97143 Mon Sep 17 00:00:00 2001 From: bluss Date: Wed, 5 May 2021 00:08:18 +0200 Subject: [PATCH] FEAT: Add method squeeze_into This method can squeeze into a particular dimensionality. Squeezing means removing axes of length 1. When squeezing to a particular dimensionality, we may have to still pad out the shape with extra 1-shape axes to fill the dimensionality. --- src/dimension/mod.rs | 143 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 9 deletions(-) diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 2584152b3..186096a5f 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -729,6 +729,8 @@ where } /// Remove axes with length one, except never removing the last axis. +/// +/// This function is a no-op for const dim. pub(crate) fn squeeze(dim: &mut D, strides: &mut D) where D: Dimension, @@ -736,6 +738,29 @@ where if let Some(_) = D::NDIM { return; } + + // infallible for dyn dim + let (d, s) = squeeze_into(dim, strides).unwrap(); + *dim = d; + *strides = s; +} + +/// Remove axes with length one, except never removing the last axis. +/// +/// Return an error if there are more non-unitary dimensions than can be stored +/// in `E`. Infallible for dyn dim. +/// +/// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is +/// dynamic 0D, the output can be too. +/// +/// For const dim, this may instead pad the dimensionality with ones if it needs +/// to grow to fill the target dimensionality; the dimension is padded in the +/// start. +pub(crate) fn squeeze_into(dim: &D, strides: &D) -> Result<(E, E), ShapeError> +where + D: Dimension, + E: Dimension, +{ debug_assert_eq!(dim.ndim(), strides.ndim()); // Count axes with dim == 1; we keep axes with d == 0 or d > 1 @@ -743,10 +768,30 @@ where for &d in dim.slice() { if d != 1 { ndim_new += 1; } } - ndim_new = Ord::max(1, ndim_new); - let mut new_dim = D::zeros(ndim_new); - let mut new_strides = D::zeros(ndim_new); + let mut fill_ones = 0; + if let Some(e_ndim) = E::NDIM { + if e_ndim < ndim_new { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + fill_ones = e_ndim - ndim_new; + ndim_new = e_ndim; + } else { + // dynamic-dimensional + // use minimum one dimension unless input has less than one dim + if dim.ndim() > 0 && ndim_new == 0 { + ndim_new = 1; + fill_ones = 1; + } + } + + let mut new_dim = E::zeros(ndim_new); + let mut new_strides = E::zeros(ndim_new); let mut i = 0; + while i < fill_ones { + new_dim[i] = 1; + new_strides[i] = 1; + i += 1; + } for (&d, &s) in izip!(dim.slice(), strides.slice()) { if d != 1 { new_dim[i] = d; @@ -754,12 +799,7 @@ where i += 1; } } - if i == 0 { - new_dim[i] = 1; - new_strides[i] = 1; - } - *dim = new_dim; - *strides = new_strides; + Ok((new_dim, new_strides)) } @@ -1148,6 +1188,91 @@ mod test { assert_eq!(s, sans); } + #[test] + #[cfg(feature = "std")] + fn test_squeeze_into() { + use super::squeeze_into; + + let dyndim = Dim::<&[usize]>; + + // squeeze to ixdyn + let d = dyndim(&[1, 2, 1, 1, 3, 1]); + let s = dyndim(&[!0, !0, !0, 9, 10, !0]); + let dans = dyndim(&[2, 3]); + let sans = dyndim(&[!0, 10]); + let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap(); + assert_eq!(d2, dans); + assert_eq!(s2, sans); + + // squeeze to ixdyn does not go below 1D + let d = dyndim(&[1, 1]); + let s = dyndim(&[3, 4]); + let dans = dyndim(&[1]); + let sans = dyndim(&[1]); + let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap(); + assert_eq!(d2, dans); + assert_eq!(s2, sans); + + let d = Dim([1, 1]); + let s = Dim([3, 4]); + let dans = Dim([1]); + let sans = Dim([1]); + let (d2, s2) = squeeze_into::<_, Ix1>(&d, &s).unwrap(); + assert_eq!(d2, dans); + assert_eq!(s2, sans); + + // squeeze to zero-dim + let (d2, s2) = squeeze_into::<_, Ix0>(&d, &s).unwrap(); + assert_eq!(d2, Ix0()); + assert_eq!(s2, Ix0()); + + let d = Dim([0, 1, 3, 4]); + let s = Dim([2, 3, 4, 5]); + let dans = Dim([0, 3, 4]); + let sans = Dim([2, 4, 5]); + let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap(); + assert_eq!(d2, dans); + assert_eq!(s2, sans); + + // Pad with ones + let d = Dim([0, 1, 3, 1]); + let s = Dim([2, 3, 4, 5]); + let dans = Dim([1, 0, 3]); + let sans = Dim([1, 2, 4]); + let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap(); + assert_eq!(d2, dans); + assert_eq!(s2, sans); + + // Try something that doesn't fit + let d = Dim([0, 1, 3, 1]); + let s = Dim([2, 3, 4, 5]); + let res = squeeze_into::<_, Ix1>(&d, &s); + assert!(res.is_err()); + let res = squeeze_into::<_, Ix0>(&d, &s); + assert!(res.is_err()); + + // Squeeze 0d to 0d + let d = Dim([]); + let s = Dim([]); + let res = squeeze_into::<_, Ix0>(&d, &s); + assert!(res.is_ok()); + // grow 0d to 2d + let dans = Dim([1, 1]); + let sans = Dim([1, 1]); + let (d2, s2) = squeeze_into::<_, Ix2>(&d, &s).unwrap(); + assert_eq!(d2, dans); + assert_eq!(s2, sans); + + // Squeeze 0d to 0d dynamic + let d = dyndim(&[]); + let s = dyndim(&[]); + let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap(); + let dans = d; + let sans = s; + assert_eq!(d2, dans); + assert_eq!(s2, sans); + } + #[test] fn test_merge_axes_from_the_back() { let dyndim = Dim::<&[usize]>;