Skip to content

Commit

Permalink
API: Add sort and take functions for COO format
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jan 12, 2024
1 parent f2d8f1d commit 8502cfa
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 21 deletions.
4 changes: 4 additions & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@
nansum,
result_type,
roll,
sort,
take,
tril,
triu,
unique_counts,
Expand Down Expand Up @@ -278,13 +280,15 @@
"sign",
"sin",
"sinh",
"sort",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"subtract",
"sum",
"take",
"tan",
"tanh",
"tensordot",
Expand Down
4 changes: 4 additions & 0 deletions sparse/_coo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
nansum,
result_type,
roll,
sort,
stack,
take,
tril,
triu,
unique_counts,
Expand All @@ -42,6 +44,8 @@
"nanmax",
"nanreduce",
"roll",
"sort",
"take",
"kron",
"argwhere",
"argmax",
Expand Down
175 changes: 154 additions & 21 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from collections.abc import Iterable
from functools import reduce
from typing import NamedTuple, Optional, Tuple
from typing import Any, NamedTuple, Optional, Tuple

import numba

Expand Down Expand Up @@ -1096,14 +1096,8 @@ def unique_counts(x, /):
>>> sparse.unique_counts(x)
UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2]))
"""
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)
x = _validate_coo_input(x)

x = x.flatten()
values, counts = np.unique(x.data, return_counts=True)
Expand Down Expand Up @@ -1143,6 +1137,113 @@ def unique_values(x, /):
>>> sparse.unique_values(x)
array([-3, 0, 1, 2])
"""

x = _validate_coo_input(x)

x = x.flatten()
values = np.unique(x.data)
if x.nnz < x.size:
values = np.sort(np.concatenate([[x.fill_value], values]))
return values


def sort(x, /, *, axis=-1, descending=False):
"""
Returns a sorted copy of an input array ``x``.
Parameters
----------
x : SparseArray
Input array. Should have a real-valued data type.
axis : int
Axis along which to sort. If set to ``-1``, the function must sort along
the last axis. Default: ``-1``.
descending : bool
Sort order. If ``True``, the array must be sorted in descending order (by value).
If ``False``, the array must be sorted in ascending order (by value).
Default: ``False``.
Returns
-------
out : COO
A sorted array.
Raises
------
ValueError
If the input array isn't and can't be converted to COO format.
Examples
--------
>>> import sparse
>>> x = sparse.COO.from_numpy([1, 0, 2, 0, 2, -3])
>>> sparse.sort(x).todense()
array([-3, 0, 0, 1, 2, 2])
>>> sparse.sort(x, descending=True).todense()
array([ 2, 2, 1, 0, 0, -3])
"""

from .._common import moveaxis

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
sparse._common
begins an import cycle.

x = _validate_coo_input(x)

original_ndim = x.ndim
if x.ndim == 1:
x = x[None, :]
axis = -1

Check warning on line 1194 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1193-L1194

Added lines #L1193 - L1194 were not covered by tests

x = moveaxis(x, source=axis, destination=-1)
x_shape = x.shape
x = x.reshape((np.prod(x_shape[:-1]), x_shape[-1]))

_sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)

x = x.reshape(x_shape[:-1] + (x_shape[-1],))
x = moveaxis(x, source=-1, destination=axis)

return x if original_ndim == x.ndim else x.squeeze()


def take(x, indices, /, *, axis=None):
"""
Returns elements of an array along an axis.
Parameters
----------
x : SparseArray
Input array.
indices : ndarray
Array indices. The array must be one-dimensional and have an integer data type.
axis : int
Axis over which to select values. If ``axis`` is negative, the function must
determine the axis along which to select values by counting from the last dimension.
For ``None``, the flattened input array is used. Default: ``None``.
Returns
-------
out : COO
A COO array with requested indices.
Raises
------
ValueError
If the input array isn't and can't be converted to COO format.
"""

x = _validate_coo_input(x)

if axis is None:
x = x.flatten()
return x[indices]

axis = normalize_axis(axis, x.ndim)
full_index = (slice(None),) * axis + (indices, ...)
return x[full_index]


def _validate_coo_input(x: Any) -> "COO":
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
Expand All @@ -1152,11 +1253,50 @@ def unique_values(x, /):
elif not isinstance(x, COO):
x = x.asformat(COO)

x = x.flatten()
values = np.unique(x.data)
if x.nnz < x.size:
values = np.sort(np.concatenate([[x.fill_value], values]))
return values
return x


@numba.jit(nopython=True, nogil=True)
def _sort_coo(
coords: np.ndarray,
data: np.ndarray,
fill_value: float,
sort_axis_len: int,
descending: bool,
) -> None:
assert coords.shape[0] == 2
group_coords = coords[0, :]
sort_coords = coords[1, :]

Check warning on line 1269 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1267-L1269

Added lines #L1267 - L1269 were not covered by tests

result_indices = np.empty_like(sort_coords)
offset = 0

Check warning on line 1272 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1271-L1272

Added lines #L1271 - L1272 were not covered by tests

for uniq in np.unique(group_coords):
args = np.argwhere(group_coords == uniq).copy()
args = np.reshape(args, -1)
args = np.atleast_1d(args)

Check warning on line 1277 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1274-L1277

Added lines #L1274 - L1277 were not covered by tests

fill_value_count = sort_axis_len - args.size

Check warning on line 1279 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1279

Added line #L1279 was not covered by tests

if args.size > 1:

Check warning on line 1281 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1281

Added line #L1281 was not covered by tests
# np.sort in numba doesn't support `np.sort`'s arguments
# so `stable` can't be supported.
# https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
data[args] = np.sort(data[args])
if descending:
data[args] = data[args][::-1]

Check warning on line 1287 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1285-L1287

Added lines #L1285 - L1287 were not covered by tests

# define indices
indices = np.arange(args.size)
for pos in range(args.size):
if (fill_value < data[args][pos] and not descending) or (fill_value > data[args][pos] and descending):
indices[pos:] += fill_value_count
break

Check warning on line 1294 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1290-L1294

Added lines #L1290 - L1294 were not covered by tests

result_indices[offset:offset+len(indices)] = indices
offset += len(indices)

Check warning on line 1297 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1296-L1297

Added lines #L1296 - L1297 were not covered by tests

sort_coords[:] = result_indices

Check warning on line 1299 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1299

Added line #L1299 was not covered by tests


@numba.jit(nopython=True, nogil=True)
Expand Down Expand Up @@ -1216,14 +1356,7 @@ def _arg_minmax_common(
assert mode in ("max", "min")
max_mode_flag = mode == "max"

from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)
x = _validate_coo_input(x)

if not isinstance(axis, (int, type(None))):
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")
Expand Down
32 changes: 32 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,3 +1777,35 @@ def test_unique_values(self, arr, fill_value):
def test_input_validation(self, func):
with pytest.raises(ValueError, match=r"Input must be an instance of SparseArray"):
func(self.arr)


@pytest.mark.parametrize("arr", [
np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64),
np.array([[[2, 0], [0, 5]], [[1, 0], [4, 0]], [[0, 1], [0, -1]]], dtype=np.float64)
])
@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
@pytest.mark.parametrize("axis", [0, 1, -1])
@pytest.mark.parametrize("descending", [False, True])
def test_sort(arr, fill_value, axis, descending):
s_arr = sparse.COO.from_numpy(arr, fill_value)

result = sparse.sort(s_arr, axis=axis, descending=descending)
expected = -np.sort(-arr, axis=axis) if descending else np.sort(arr, axis=axis)

np.testing.assert_equal(result.todense(), expected)


@pytest.mark.parametrize("fill_value", [-1, 0, 1, 3])
@pytest.mark.parametrize(
"indices,axis",
[([1], 0,), ([2, 1], 1), ([1, 2, 3], 2), ([2, 3], -1), ([5, 3, 7, 8], None)]
)
def test_take(fill_value, indices, axis):
arr = np.arange(24).reshape((2,3,4))

s_arr = sparse.COO.from_numpy(arr, fill_value)

result = sparse.take(s_arr, np.array(indices), axis=axis)
expected = np.take(arr, indices, axis)

np.testing.assert_equal(result.todense(), expected)
2 changes: 2 additions & 0 deletions sparse/tests/test_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,15 @@ def test_namespace():
"sign",
"sin",
"sinh",
"sort",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"subtract",
"sum",
"take",
"tan",
"tanh",
"tensordot",
Expand Down

0 comments on commit 8502cfa

Please sign in to comment.