diff --git a/sparse/_coo/common.py b/sparse/_coo/common.py index cb69d2be..cc23aff6 100644 --- a/sparse/_coo/common.py +++ b/sparse/_coo/common.py @@ -1197,7 +1197,9 @@ def sort(x, /, *, axis=-1, descending=False): x_shape = x.shape x = x.reshape((np.prod(x_shape[:-1]), x_shape[-1])) - _sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending) + _sort_coo( + x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending + ) x = x.reshape(x_shape[:-1] + (x_shape[-1],)) x = moveaxis(x, source=-1, destination=axis) @@ -1237,13 +1239,13 @@ def take(x, indices, /, *, axis=None): if axis is None: x = x.flatten() return x[indices] - + axis = normalize_axis(axis, x.ndim) full_index = (slice(None),) * axis + (indices, ...) return x[full_index] -def _validate_coo_input(x: Any) -> "COO": +def _validate_coo_input(x: Any): from .core import COO if isinstance(x, scipy.sparse.spmatrix): @@ -1269,31 +1271,36 @@ def _sort_coo( sort_coords = coords[1, :] result_indices = np.empty_like(sort_coords) - offset = 0 - - for uniq in np.unique(group_coords): - args = np.argwhere(group_coords == uniq).copy() - args = np.reshape(args, -1) - args = np.atleast_1d(args) - - fill_value_count = sort_axis_len - args.size - - if args.size > 1: - # np.sort in numba doesn't support `np.sort`'s arguments - # so `stable` can't be supported. + offset = 0 # tracks where the current group starts + + # iterate through all groups and sort each one of them + for unique_val in np.unique(group_coords): + # .copy() required by numba, as `reshape` expects a continous array + group = np.argwhere(group_coords == unique_val).copy() + group = np.reshape(group, -1) + group = np.atleast_1d(group) + + # SORT VALUES + if group.size > 1: + # np.sort in numba doesn't support `np.sort`'s arguments so `stable` + # keyword can't be supported. # https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods - data[args] = np.sort(data[args]) + data[group] = np.sort(data[group]) if descending: - data[args] = data[args][::-1] - - # define indices - indices = np.arange(args.size) - for pos in range(args.size): - if (fill_value < data[args][pos] and not descending) or (fill_value > data[args][pos] and descending): + data[group] = data[group][::-1] + + # SORT INDICES + fill_value_count = sort_axis_len - group.size + indices = np.arange(group.size) + # find a place where fill_value would be + for pos in range(group.size): + if ( + (not descending and fill_value < data[group][pos]) or + (descending and fill_value > data[group][pos]) + ): indices[pos:] += fill_value_count break - - result_indices[offset:offset+len(indices)] = indices + result_indices[offset : offset + len(indices)] = indices offset += len(indices) sort_coords[:] = result_indices diff --git a/sparse/tests/test_coo.py b/sparse/tests/test_coo.py index ba62ffb7..f27ffd7a 100644 --- a/sparse/tests/test_coo.py +++ b/sparse/tests/test_coo.py @@ -1779,15 +1779,18 @@ def test_input_validation(self, func): func(self.arr) -@pytest.mark.parametrize("arr", [ - np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64), - np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64) -]) +@pytest.mark.parametrize( + "arr", + [ + np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64), + np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64), + ], +) @pytest.mark.parametrize("fill_value", [-1, 0, 1, 3]) @pytest.mark.parametrize("axis", [0, 1, -1]) @pytest.mark.parametrize("descending", [False, True]) def test_sort(arr, fill_value, axis, descending): - s_arr = sparse.COO.from_numpy(arr, fill_value) + s_arr = sparse.COO.from_numpy(arr, fill_value) result = sparse.sort(s_arr, axis=axis, descending=descending) expected = -np.sort(-arr, axis=axis) if descending else np.sort(arr, axis=axis) @@ -1798,10 +1801,19 @@ def test_sort(arr, fill_value, axis, descending): @pytest.mark.parametrize("fill_value", [-1, 0, 1, 3]) @pytest.mark.parametrize( "indices,axis", - [([1], 0,), ([2, 1], 1), ([1, 2, 3], 2), ([2, 3], -1), ([5, 3, 7, 8], None)] + [ + ( + [1], + 0, + ), + ([2, 1], 1), + ([1, 2, 3], 2), + ([2, 3], -1), + ([5, 3, 7, 8], None), + ], ) def test_take(fill_value, indices, axis): - arr = np.arange(24).reshape((2,3,4)) + arr = np.arange(24).reshape((2, 3, 4)) s_arr = sparse.COO.from_numpy(arr, fill_value)