Skip to content

Commit ff2084f

Browse files
pre-commit-ci[bot]mtsokol
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2cbe91b commit ff2084f

File tree

2 files changed

+50
-31
lines changed

2 files changed

+50
-31
lines changed

sparse/_coo/common.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,7 +1197,9 @@ def sort(x, /, *, axis=-1, descending=False):
11971197
x_shape = x.shape
11981198
x = x.reshape((np.prod(x_shape[:-1]), x_shape[-1]))
11991199

1200-
_sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)
1200+
_sort_coo(
1201+
x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending
1202+
)
12011203

12021204
x = x.reshape(x_shape[:-1] + (x_shape[-1],))
12031205
x = moveaxis(x, source=-1, destination=axis)
@@ -1237,13 +1239,13 @@ def take(x, indices, /, *, axis=None):
12371239
if axis is None:
12381240
x = x.flatten()
12391241
return x[indices]
1240-
1242+
12411243
axis = normalize_axis(axis, x.ndim)
12421244
full_index = (slice(None),) * axis + (indices, ...)
12431245
return x[full_index]
12441246

12451247

1246-
def _validate_coo_input(x: Any) -> "COO":
1248+
def _validate_coo_input(x: Any):
12471249
from .core import COO
12481250

12491251
if isinstance(x, scipy.sparse.spmatrix):
@@ -1269,31 +1271,36 @@ def _sort_coo(
12691271
sort_coords = coords[1, :]
12701272

12711273
result_indices = np.empty_like(sort_coords)
1272-
offset = 0
1273-
1274-
for uniq in np.unique(group_coords):
1275-
args = np.argwhere(group_coords == uniq).copy()
1276-
args = np.reshape(args, -1)
1277-
args = np.atleast_1d(args)
1278-
1279-
fill_value_count = sort_axis_len - args.size
1280-
1281-
if args.size > 1:
1282-
# np.sort in numba doesn't support `np.sort`'s arguments
1283-
# so `stable` can't be supported.
1274+
offset = 0 # tracks where the current group starts
1275+
1276+
# iterate through all groups and sort each one of them
1277+
for unique_val in np.unique(group_coords):
1278+
# .copy() required by numba, as `reshape` expects a continous array
1279+
group = np.argwhere(group_coords == unique_val).copy()
1280+
group = np.reshape(group, -1)
1281+
group = np.atleast_1d(group)
1282+
1283+
# SORT VALUES
1284+
if group.size > 1:
1285+
# np.sort in numba doesn't support `np.sort`'s arguments so `stable`
1286+
# keyword can't be supported.
12841287
# https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
1285-
data[args] = np.sort(data[args])
1288+
data[group] = np.sort(data[group])
12861289
if descending:
1287-
data[args] = data[args][::-1]
1288-
1289-
# define indices
1290-
indices = np.arange(args.size)
1291-
for pos in range(args.size):
1292-
if (fill_value < data[args][pos] and not descending) or (fill_value > data[args][pos] and descending):
1290+
data[group] = data[group][::-1]
1291+
1292+
# SORT INDICES
1293+
fill_value_count = sort_axis_len - group.size
1294+
indices = np.arange(group.size)
1295+
# find a place where fill_value would be
1296+
for pos in range(group.size):
1297+
if (
1298+
(not descending and fill_value < data[group][pos]) or
1299+
(descending and fill_value > data[group][pos])
1300+
):
12931301
indices[pos:] += fill_value_count
12941302
break
1295-
1296-
result_indices[offset:offset+len(indices)] = indices
1303+
result_indices[offset : offset + len(indices)] = indices
12971304
offset += len(indices)
12981305

12991306
sort_coords[:] = result_indices

sparse/tests/test_coo.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,15 +1779,18 @@ def test_input_validation(self, func):
17791779
func(self.arr)
17801780

17811781

1782-
@pytest.mark.parametrize("arr", [
1783-
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
1784-
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64)
1785-
])
1782+
@pytest.mark.parametrize(
1783+
"arr",
1784+
[
1785+
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
1786+
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64),
1787+
],
1788+
)
17861789
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
17871790
@pytest.mark.parametrize("axis", [0, 1, -1])
17881791
@pytest.mark.parametrize("descending", [False, True])
17891792
def test_sort(arr, fill_value, axis, descending):
1790-
s_arr = sparse.COO.from_numpy(arr, fill_value)
1793+
s_arr = sparse.COO.from_numpy(arr, fill_value)
17911794

17921795
result = sparse.sort(s_arr, axis=axis, descending=descending)
17931796
expected = -np.sort(-arr, axis=axis) if descending else np.sort(arr, axis=axis)
@@ -1798,10 +1801,19 @@ def test_sort(arr, fill_value, axis, descending):
17981801
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
17991802
@pytest.mark.parametrize(
18001803
"indices,axis",
1801-
[([1], 0,), ([2, 1], 1), ([1, 2, 3], 2), ([2, 3], -1), ([5, 3, 7, 8], None)]
1804+
[
1805+
(
1806+
[1],
1807+
0,
1808+
),
1809+
([2, 1], 1),
1810+
([1, 2, 3], 2),
1811+
([2, 3], -1),
1812+
([5, 3, 7, 8], None),
1813+
],
18021814
)
18031815
def test_take(fill_value, indices, axis):
1804-
arr = np.arange(24).reshape((2,3,4))
1816+
arr = np.arange(24).reshape((2, 3, 4))
18051817

18061818
s_arr = sparse.COO.from_numpy(arr, fill_value)
18071819

0 commit comments

Comments
 (0)