Skip to content

Commit

Permalink
API: Add squeeze method to COO (#613)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol authored Dec 22, 2023
1 parent 01d8934 commit f522926
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
5 changes: 5 additions & 0 deletions sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2249,6 +2249,11 @@ def astype(x, dtype, /, *, copy=True):
return x.astype(dtype, copy=copy)


@_support_numpy
def squeeze(x, /, axis=None):
return x.squeeze(axis=axis)


@_support_numpy
def broadcast_to(x, /, shape):
return x.broadcast_to(shape)
Expand Down
53 changes: 53 additions & 0 deletions sparse/_coo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,59 @@ def reshape(self, shape, order="C"):
self._cache["reshape"].append((shape, result))
return result

def squeeze(self, axis=None):
"""
Removes singleton dimensions (axes) from ``x``.
Parameters
----------
axis : Union[None, int, Tuple[int, ...]]
The axis (or axes) to squeeze. If a specified axis has a size greater than one,
a `ValueError` is raised. ``axis=None`` removes all singleton dimensions.
Default: ``None``.
Returns
-------
COO
The output array without ``axis`` dimensions.
Examples
--------
>>> s = COO.from_numpy(np.eye(2)).reshape((2, 1, 2, 1))
>>> s.squeeze().shape
(2, 2)
>>> s.squeeze(axis=1).shape
(2, 2, 1)
"""
squeezable_dims = tuple([d for d in range(self.ndim) if self.shape[d] == 1])

if axis is None:
axis = squeezable_dims
if isinstance(axis, int):
axis = (axis,)
elif isinstance(axis, Iterable):
axis = tuple(axis)
else:
raise ValueError(f"Invalid axis parameter: `{axis}`.")

for d in axis:
if not d in squeezable_dims:
raise ValueError(
f"Specified axis `{d}` has a size greater than one: {self.shape[d]}"
)

retained_dims = [d for d in range(self.ndim) if not d in axis]

coords = self.coords[retained_dims, :]
shape = tuple([s for idx, s in enumerate(self.shape) if idx in retained_dims])

return COO(
coords,
self.data,
shape,
has_duplicates=False,
sorted=True,
cache=self._cache is not None,
fill_value=self.fill_value,
)

def resize(self, *args, refcheck=True, coords_dtype=np.intp):
"""
This method changes the shape and size of an array in-place.
Expand Down
40 changes: 40 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1743,3 +1743,43 @@ def test_isinf_isnan(config):
expected = getattr(np, func_name)(arr)

np.testing.assert_equal(result, expected)


class TestSqueeze:
eye_arr = np.eye(2).reshape(1, 2, 1, 2)

@pytest.mark.parametrize(
"arr_and_axis",
[
(eye_arr, None),
(eye_arr, 0),
(eye_arr, 2),
(eye_arr, (0, 2)),
(np.zeros((5,)), None),
],
)
def test_squeeze(self, arr_and_axis):
arr, axis = arr_and_axis

s_arr = sparse.COO.from_numpy(arr)

result_1 = sparse.squeeze(s_arr, axis=axis).todense()
result_2 = s_arr.squeeze(axis=axis).todense()
expected = np.squeeze(arr, axis=axis)

np.testing.assert_equal(result_1, result_2)
np.testing.assert_equal(result_1, expected)

def test_squeeze_validation(self):
s_arr = sparse.COO.from_numpy(np.eye(3))

with pytest.raises(IndexError, match="tuple index out of range"):
s_arr.squeeze(3)

with pytest.raises(ValueError, match="Invalid axis parameter: `1.1`."):
s_arr.squeeze(1.1)

with pytest.raises(
ValueError, match="Specified axis `0` has a size greater than one: 3"
):
s_arr.squeeze(0)

0 comments on commit f522926

Please sign in to comment.