Skip to content

Commit

Permalink
Merge branch 'master' into new-argmax-argmin
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi authored and mtsokol committed Jan 2, 2024
2 parents 0422a7c + f522926 commit cae0fe7
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 3 deletions.
5 changes: 5 additions & 0 deletions sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2249,6 +2249,11 @@ def astype(x, dtype, /, *, copy=True):
return x.astype(dtype, copy=copy)


@_support_numpy
def squeeze(x, /, axis=None):
return x.squeeze(axis=axis)


@_support_numpy
def broadcast_to(x, /, shape):
return x.broadcast_to(shape)
Expand Down
6 changes: 3 additions & 3 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import operator
import warnings
from collections.abc import Iterable
from typing import Callable, Optional, Tuple
from typing import Optional, Tuple

import numpy as np
import scipy.sparse
Expand Down Expand Up @@ -1099,8 +1099,8 @@ def _compute_minmax_args(
result_data = []

Check warning on line 1099 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1098-L1099

Added lines #L1098 - L1099 were not covered by tests

# we iterate through each trace
for idx in np.nditer(result_indices):
mask = index_coords == idx
for result_index in np.nditer(result_indices):
mask = index_coords == result_index
masked_reduce_coords = reduce_coords[mask]
masked_data = data[mask]

Check warning on line 1105 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1102-L1105

Added lines #L1102 - L1105 were not covered by tests

Expand Down
53 changes: 53 additions & 0 deletions sparse/_coo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,59 @@ def reshape(self, shape, order="C"):
self._cache["reshape"].append((shape, result))
return result

def squeeze(self, axis=None):
"""
Removes singleton dimensions (axes) from ``x``.
Parameters
----------
axis : Union[None, int, Tuple[int, ...]]
The axis (or axes) to squeeze. If a specified axis has a size greater than one,
a `ValueError` is raised. ``axis=None`` removes all singleton dimensions.
Default: ``None``.
Returns
-------
COO
The output array without ``axis`` dimensions.
Examples
--------
>>> s = COO.from_numpy(np.eye(2)).reshape((2, 1, 2, 1))
>>> s.squeeze().shape
(2, 2)
>>> s.squeeze(axis=1).shape
(2, 2, 1)
"""
squeezable_dims = tuple([d for d in range(self.ndim) if self.shape[d] == 1])

if axis is None:
axis = squeezable_dims
if isinstance(axis, int):
axis = (axis,)
elif isinstance(axis, Iterable):
axis = tuple(axis)
else:
raise ValueError(f"Invalid axis parameter: `{axis}`.")

for d in axis:
if not d in squeezable_dims:
raise ValueError(
f"Specified axis `{d}` has a size greater than one: {self.shape[d]}"
)

retained_dims = [d for d in range(self.ndim) if not d in axis]

coords = self.coords[retained_dims, :]
shape = tuple([s for idx, s in enumerate(self.shape) if idx in retained_dims])

return COO(
coords,
self.data,
shape,
has_duplicates=False,
sorted=True,
cache=self._cache is not None,
fill_value=self.fill_value,
)

def resize(self, *args, refcheck=True, coords_dtype=np.intp):
"""
This method changes the shape and size of an array in-place.
Expand Down
40 changes: 40 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,3 +1752,43 @@ def test_isinf_isnan(config):
expected = getattr(np, func_name)(arr)

np.testing.assert_equal(result, expected)


class TestSqueeze:
eye_arr = np.eye(2).reshape(1, 2, 1, 2)

@pytest.mark.parametrize(
"arr_and_axis",
[
(eye_arr, None),
(eye_arr, 0),
(eye_arr, 2),
(eye_arr, (0, 2)),
(np.zeros((5,)), None),
],
)
def test_squeeze(self, arr_and_axis):
arr, axis = arr_and_axis

s_arr = sparse.COO.from_numpy(arr)

result_1 = sparse.squeeze(s_arr, axis=axis).todense()
result_2 = s_arr.squeeze(axis=axis).todense()
expected = np.squeeze(arr, axis=axis)

np.testing.assert_equal(result_1, result_2)
np.testing.assert_equal(result_1, expected)

def test_squeeze_validation(self):
s_arr = sparse.COO.from_numpy(np.eye(3))

with pytest.raises(IndexError, match="tuple index out of range"):
s_arr.squeeze(3)

with pytest.raises(ValueError, match="Invalid axis parameter: `1.1`."):
s_arr.squeeze(1.1)

with pytest.raises(
ValueError, match="Specified axis `0` has a size greater than one: 3"
):
s_arr.squeeze(0)

0 comments on commit cae0fe7

Please sign in to comment.