diff --git a/src/impl_dyn.rs b/src/impl_dyn.rs index 836234cec..b86c5dd69 100644 --- a/src/impl_dyn.rs +++ b/src/impl_dyn.rs @@ -58,4 +58,60 @@ where S: Data self.dim = self.dim.remove_axis(axis); self.strides = self.strides.remove_axis(axis); } + + /// Remove axes of length 1 and return the modified array. + /// + /// If the array has more the one dimension, the result array will always + /// have at least one dimension, even if it has a length of 1. + /// + /// ``` + /// use ndarray::{arr1, arr2, arr3}; + /// + /// let a = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn(); + /// assert_eq!(a.shape(), &[2, 1, 3]); + /// let b = a.squeeze(); + /// assert_eq!(b, arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn()); + /// assert_eq!(b.shape(), &[2, 3]); + /// + /// let c = arr2(&[[1]]).into_dyn(); + /// assert_eq!(c.shape(), &[1, 1]); + /// let d = c.squeeze(); + /// assert_eq!(d, arr1(&[1]).into_dyn()); + /// assert_eq!(d.shape(), &[1]); + /// ``` + #[track_caller] + pub fn squeeze(self) -> Self + { + let mut out = self; + for axis in (0..out.shape().len()).rev() { + if out.shape()[axis] == 1 && out.shape().len() > 1 { + out = out.remove_axis(Axis(axis)); + } + } + out + } +} + +#[cfg(test)] +mod tests +{ + use crate::{arr1, arr2, arr3}; + + #[test] + fn test_squeeze() + { + let a = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn(); + assert_eq!(a.shape(), &[2, 1, 3]); + + let b = a.squeeze(); + assert_eq!(b, arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn()); + assert_eq!(b.shape(), &[2, 3]); + + let c = arr2(&[[1]]).into_dyn(); + assert_eq!(c.shape(), &[1, 1]); + + let d = c.squeeze(); + assert_eq!(d, arr1(&[1]).into_dyn()); + assert_eq!(d.shape(), &[1]); + } }