diff --git a/sparse/_common.py b/sparse/_common.py index d8a37d97..576206c9 100644 --- a/sparse/_common.py +++ b/sparse/_common.py @@ -2249,6 +2249,11 @@ def astype(x, dtype, /, *, copy=True): return x.astype(dtype, copy=copy) +@_support_numpy +def squeeze(x, /, axis=None): + return x.squeeze(axis=axis) + + @_support_numpy def broadcast_to(x, /, shape): return x.broadcast_to(shape) diff --git a/sparse/_coo/core.py b/sparse/_coo/core.py index eff98bd7..c7932772 100644 --- a/sparse/_coo/core.py +++ b/sparse/_coo/core.py @@ -1086,6 +1086,59 @@ def reshape(self, shape, order="C"): self._cache["reshape"].append((shape, result)) return result + def squeeze(self, axis=None): + """ + Removes singleton dimensions (axes) from ``x``. + Parameters + ---------- + axis : Union[None, int, Tuple[int, ...]] + The axis (or axes) to squeeze. If a specified axis has a size greater than one, + a `ValueError` is raised. ``axis=None`` removes all singleton dimensions. + Default: ``None``. + Returns + ------- + COO + The output array without ``axis`` dimensions. + Examples + -------- + >>> s = COO.from_numpy(np.eye(2)).reshape((2, 1, 2, 1)) + >>> s.squeeze().shape + (2, 2) + >>> s.squeeze(axis=1).shape + (2, 2, 1) + """ + squeezable_dims = tuple([d for d in range(self.ndim) if self.shape[d] == 1]) + + if axis is None: + axis = squeezable_dims + if isinstance(axis, int): + axis = (axis,) + elif isinstance(axis, Iterable): + axis = tuple(axis) + else: + raise ValueError(f"Invalid axis parameter: `{axis}`.") + + for d in axis: + if not d in squeezable_dims: + raise ValueError( + f"Specified axis `{d}` has a size greater than one: {self.shape[d]}" + ) + + retained_dims = [d for d in range(self.ndim) if not d in axis] + + coords = self.coords[retained_dims, :] + shape = tuple([s for idx, s in enumerate(self.shape) if idx in retained_dims]) + + return COO( + coords, + self.data, + shape, + has_duplicates=False, + sorted=True, + cache=self._cache is not None, + fill_value=self.fill_value, + ) + def resize(self, *args, refcheck=True, coords_dtype=np.intp): """ This method changes the shape and size of an array in-place. diff --git a/sparse/tests/test_coo.py b/sparse/tests/test_coo.py index b64d8a06..1f811768 100644 --- a/sparse/tests/test_coo.py +++ b/sparse/tests/test_coo.py @@ -1752,3 +1752,43 @@ def test_isinf_isnan(config): expected = getattr(np, func_name)(arr) np.testing.assert_equal(result, expected) + + +class TestSqueeze: + eye_arr = np.eye(2).reshape(1, 2, 1, 2) + + @pytest.mark.parametrize( + "arr_and_axis", + [ + (eye_arr, None), + (eye_arr, 0), + (eye_arr, 2), + (eye_arr, (0, 2)), + (np.zeros((5,)), None), + ], + ) + def test_squeeze(self, arr_and_axis): + arr, axis = arr_and_axis + + s_arr = sparse.COO.from_numpy(arr) + + result_1 = sparse.squeeze(s_arr, axis=axis).todense() + result_2 = s_arr.squeeze(axis=axis).todense() + expected = np.squeeze(arr, axis=axis) + + np.testing.assert_equal(result_1, result_2) + np.testing.assert_equal(result_1, expected) + + def test_squeeze_validation(self): + s_arr = sparse.COO.from_numpy(np.eye(3)) + + with pytest.raises(IndexError, match="tuple index out of range"): + s_arr.squeeze(3) + + with pytest.raises(ValueError, match="Invalid axis parameter: `1.1`."): + s_arr.squeeze(1.1) + + with pytest.raises( + ValueError, match="Specified axis `0` has a size greater than one: 3" + ): + s_arr.squeeze(0)