Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and mtsokol committed Jan 12, 2024
1 parent 2cbe91b commit ff2084f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 31 deletions.
55 changes: 31 additions & 24 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,9 @@ def sort(x, /, *, axis=-1, descending=False):
x_shape = x.shape
x = x.reshape((np.prod(x_shape[:-1]), x_shape[-1]))

_sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)
_sort_coo(
x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending
)

x = x.reshape(x_shape[:-1] + (x_shape[-1],))
x = moveaxis(x, source=-1, destination=axis)
Expand Down Expand Up @@ -1237,13 +1239,13 @@ def take(x, indices, /, *, axis=None):
if axis is None:
x = x.flatten()
return x[indices]

axis = normalize_axis(axis, x.ndim)
full_index = (slice(None),) * axis + (indices, ...)
return x[full_index]


def _validate_coo_input(x: Any) -> "COO":
def _validate_coo_input(x: Any):
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
Expand All @@ -1269,31 +1271,36 @@ def _sort_coo(
sort_coords = coords[1, :]

Check warning on line 1271 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1269-L1271

Added lines #L1269 - L1271 were not covered by tests

result_indices = np.empty_like(sort_coords)
offset = 0

for uniq in np.unique(group_coords):
args = np.argwhere(group_coords == uniq).copy()
args = np.reshape(args, -1)
args = np.atleast_1d(args)

fill_value_count = sort_axis_len - args.size

if args.size > 1:
# np.sort in numba doesn't support `np.sort`'s arguments
# so `stable` can't be supported.
offset = 0 # tracks where the current group starts

Check warning on line 1274 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1273-L1274

Added lines #L1273 - L1274 were not covered by tests

# iterate through all groups and sort each one of them
for unique_val in np.unique(group_coords):

Check warning on line 1277 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1277

Added line #L1277 was not covered by tests
# .copy() required by numba, as `reshape` expects a continous array
group = np.argwhere(group_coords == unique_val).copy()
group = np.reshape(group, -1)
group = np.atleast_1d(group)

Check warning on line 1281 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1279-L1281

Added lines #L1279 - L1281 were not covered by tests

# SORT VALUES
if group.size > 1:

Check warning on line 1284 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1284

Added line #L1284 was not covered by tests
# np.sort in numba doesn't support `np.sort`'s arguments so `stable`
# keyword can't be supported.
# https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
data[args] = np.sort(data[args])
data[group] = np.sort(data[group])
if descending:
data[args] = data[args][::-1]

# define indices
indices = np.arange(args.size)
for pos in range(args.size):
if (fill_value < data[args][pos] and not descending) or (fill_value > data[args][pos] and descending):
data[group] = data[group][::-1]

Check warning on line 1290 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1288-L1290

Added lines #L1288 - L1290 were not covered by tests

# SORT INDICES
fill_value_count = sort_axis_len - group.size
indices = np.arange(group.size)

Check warning on line 1294 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1293-L1294

Added lines #L1293 - L1294 were not covered by tests
# find a place where fill_value would be
for pos in range(group.size):
if (

Check warning on line 1297 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1296-L1297

Added lines #L1296 - L1297 were not covered by tests
(not descending and fill_value < data[group][pos]) or
(descending and fill_value > data[group][pos])
):
indices[pos:] += fill_value_count
break

result_indices[offset:offset+len(indices)] = indices
result_indices[offset : offset + len(indices)] = indices
offset += len(indices)

Check warning on line 1304 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1301-L1304

Added lines #L1301 - L1304 were not covered by tests

sort_coords[:] = result_indices

Check warning on line 1306 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1306

Added line #L1306 was not covered by tests
Expand Down
26 changes: 19 additions & 7 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,15 +1779,18 @@ def test_input_validation(self, func):
func(self.arr)


@pytest.mark.parametrize("arr", [
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64)
])
@pytest.mark.parametrize(
"arr",
[
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64),
],
)
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
@pytest.mark.parametrize("axis", [0, 1, -1])
@pytest.mark.parametrize("descending", [False, True])
def test_sort(arr, fill_value, axis, descending):
s_arr = sparse.COO.from_numpy(arr, fill_value)
s_arr = sparse.COO.from_numpy(arr, fill_value)

result = sparse.sort(s_arr, axis=axis, descending=descending)
expected = -np.sort(-arr, axis=axis) if descending else np.sort(arr, axis=axis)
Expand All @@ -1798,10 +1801,19 @@ def test_sort(arr, fill_value, axis, descending):
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
@pytest.mark.parametrize(
"indices,axis",
[([1], 0,), ([2, 1], 1), ([1, 2, 3], 2), ([2, 3], -1), ([5, 3, 7, 8], None)]
[
(
[1],
0,
),
([2, 1], 1),
([1, 2, 3], 2),
([2, 3], -1),
([5, 3, 7, 8], None),
],
)
def test_take(fill_value, indices, axis):
arr = np.arange(24).reshape((2,3,4))
arr = np.arange(24).reshape((2, 3, 4))

s_arr = sparse.COO.from_numpy(arr, fill_value)

Expand Down

0 comments on commit ff2084f

Please sign in to comment.