Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Refactor argmax and argmin #614

Merged
merged 2 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 127 additions & 66 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import operator
import warnings
from collections.abc import Iterable
from typing import Callable, Optional
from typing import Optional, Tuple

import numpy as np
import scipy.sparse
Expand Down Expand Up @@ -613,65 +613,6 @@
return np.transpose(a.nonzero())


def _arg_minmax_common(
x: SparseArray,
axis: Optional[int],
keepdims: bool,
comp_op: Callable,
np_arg_func: Callable,
):
""" """
assert comp_op in (operator.lt, operator.gt)
assert np_arg_func in (np.argmax, np.argmin)

if not isinstance(axis, (int, type(None))):
raise ValueError(f"axis must be int or None, but it's: {type(axis)}")

if isinstance(axis, int) and axis >= x.ndim:
raise ValueError(
f"axis {axis} is out of bounds for array of dimension {x.ndim}"
)

if x.fill_value != 0.0:
raise ValueError(
f"Only 0.0 fill value is supported, but found: {x.fill_value}."
)

if np.any(comp_op(x.data, 0.0)):
raise ValueError(
f"None of the non-zero values can be {comp_op.__name__} the fill value."
)

# fast path
if axis is None or x.ndim == 1:
x_flat = x.reshape(-1)
result = x_flat.coords[0, np_arg_func(x_flat.data)]
return np.array(result).reshape([1] * x.ndim) if keepdims else result

# search for min/max value & index for each retained axis
minmax_indexes = {}
minmax_values = {}

for idx, coord in enumerate(x.coords.T):
coord = list(coord)
axis_index = coord[axis]
coord[axis] = 0
coord = tuple(coord)
if not coord in minmax_values or comp_op(minmax_values[coord], x.data[idx]):
minmax_values[coord] = x.data[idx]
minmax_indexes[coord] = axis_index

new_shape = list(x.shape)
new_shape[axis] = 1
new_shape = tuple(new_shape)

result = np.zeros(shape=new_shape, dtype=np.intp)
for idx, minmax_index in minmax_indexes.items():
result[idx] = minmax_index

return result if keepdims else result.squeeze()


def argmax(x, /, *, axis=None, keepdims=False):
"""
Returns the indices of the maximum values along a specified axis.
Expand All @@ -697,9 +638,7 @@
the first occurrence of the maximum value. Otherwise, a non-zero-dimensional
array containing the indices of the maximum values.
"""
return _arg_minmax_common(
x, axis=axis, keepdims=keepdims, comp_op=operator.lt, np_arg_func=np.argmax
)
return _arg_minmax_common(x, axis=axis, keepdims=keepdims, mode="max")


def argmin(x, /, *, axis=None, keepdims=False):
Expand Down Expand Up @@ -727,9 +666,7 @@
the first occurrence of the minimum value. Otherwise, a non-zero-dimensional
array containing the indices of the minimum values.
"""
return _arg_minmax_common(
x, axis=axis, keepdims=keepdims, comp_op=operator.gt, np_arg_func=np.argmin
)
return _arg_minmax_common(x, axis=axis, keepdims=keepdims, mode="min")


def _replace_nan(array, value):
Expand Down Expand Up @@ -1144,3 +1081,127 @@
"""
a = asCOO(a, name="clip")
return a.clip(a_min, a_max)


@numba.jit(nopython=True, nogil=True)
def _compute_minmax_args(
coords: np.ndarray,
data: np.ndarray,
reduce_size: int,
fill_value: float,
max_mode_flag: bool,
) -> Tuple[np.ndarray, np.ndarray]:
assert coords.shape[0] == 2
reduce_coords = coords[0, :]
index_coords = coords[1, :]

Check warning on line 1096 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1094-L1096

Added lines #L1094 - L1096 were not covered by tests

result_indices = np.unique(index_coords)
result_data = []

Check warning on line 1099 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1098-L1099

Added lines #L1098 - L1099 were not covered by tests

# we iterate through each trace
for result_index in np.nditer(result_indices):
mask = index_coords == result_index
masked_reduce_coords = reduce_coords[mask]
masked_data = data[mask]

Check warning on line 1105 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1102-L1105

Added lines #L1102 - L1105 were not covered by tests

if max_mode_flag:
compared_data = operator.gt(masked_data, fill_value)

Check warning on line 1108 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1107-L1108

Added lines #L1107 - L1108 were not covered by tests
else:
compared_data = operator.lt(masked_data, fill_value)

Check warning on line 1110 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1110

Added line #L1110 was not covered by tests

if np.any(compared_data) or len(masked_data) == reduce_size:

Check warning on line 1112 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1112

Added line #L1112 was not covered by tests
# best value is a non-fill value
best_arg = (

Check warning on line 1114 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1114

Added line #L1114 was not covered by tests
np.argmax(masked_data) if max_mode_flag else np.argmin(masked_data)
)
result_data.append(masked_reduce_coords[best_arg])

Check warning on line 1117 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1117

Added line #L1117 was not covered by tests
else:
# best value is a fill value, find the first occurrence of it
current_coord = np.array(-1, dtype=coords.dtype)
found = False
for idx, new_coord in enumerate(np.nditer(np.sort(masked_reduce_coords))):

Check warning on line 1122 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1120-L1122

Added lines #L1120 - L1122 were not covered by tests
Fixed Show fixed Hide fixed
# there is at least one fill value between consecutive non-fill values
if new_coord - current_coord > 1:
result_data.append(idx)
found = True
break
current_coord = new_coord

Check warning on line 1128 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1124-L1128

Added lines #L1124 - L1128 were not covered by tests
# get the first fill value after all non-fill values
if not found:
result_data.append(current_coord + 1)

Check warning on line 1131 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1130-L1131

Added lines #L1130 - L1131 were not covered by tests

return (result_indices, result_data)

Check warning on line 1133 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1133

Added line #L1133 was not covered by tests


def _arg_minmax_common(
x: SparseArray,
axis: Optional[int],
keepdims: bool,
mode: str,
):
"""
Internal implementation for argmax and argmin functions.
"""
assert mode in ("max", "min")
max_mode_flag = mode == "max"

if not isinstance(axis, (int, type(None))):
raise ValueError(f"axis must be int or None, but it's: {type(axis)}")

Check warning on line 1149 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1149

Added line #L1149 was not covered by tests
if isinstance(axis, int) and axis >= x.ndim:
raise ValueError(
f"axis {axis} is out of bounds for array of dimension {x.ndim}"
)
if x.ndim == 0:
raise ValueError("Input array must be at least 1-D, but it's 0-D.")

Check warning on line 1155 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1155

Added line #L1155 was not covered by tests

# If `axis` is None then we need to flatten the input array and memorize
# the original dimensionality for the final reshape operation.
axis_none_original_ndim: Optional[int] = None
if axis is None:
axis_none_original_ndim = x.ndim
x = x.reshape(-1)[:, None]
axis = 0

# A 1-D array must have one more singleton dimension.
if axis == 0 and x.ndim == 1:
x = x[:, None]

Check warning on line 1167 in sparse/_coo/common.py

View check run for this annotation

Codecov / codecov/patch

sparse/_coo/common.py#L1167

Added line #L1167 was not covered by tests

# We need to move `axis` to the front.
new_transpose = list(range(x.ndim))
new_transpose.insert(0, new_transpose.pop(axis))
new_transpose = tuple(new_transpose)

# And reshape it to 2-D (reduce axis, the rest of axes flattened)
new_shape = list(x.shape)
new_shape.insert(0, new_shape.pop(axis))
new_shape = tuple(new_shape)

x = x.transpose(new_transpose)
x = x.reshape((new_shape[0], np.prod(new_shape[1:])))

# Compute max/min arguments
result_indices, result_data = _compute_minmax_args(
x.coords.copy(),
x.data.copy(),
reduce_size=x.shape[0],
fill_value=x.fill_value,
max_mode_flag=max_mode_flag,
)

from .core import COO
hameerabbasi marked this conversation as resolved.
Show resolved Hide resolved

result = COO(
result_indices, result_data, shape=(x.shape[1],), fill_value=0, prune=True
)

# Let's reshape the result to the original shape.
result = result.reshape((1, *new_shape[1:]))
new_transpose = list(range(result.ndim))
new_transpose.insert(axis, new_transpose.pop(0))
result = result.transpose(new_transpose)

# If `axis=None` we need to reshape flattened array into original dimensionality.
if not axis_none_original_ndim is None:
result = result.reshape([1 for _ in range(axis_none_original_ndim)])

return result if keepdims else result.squeeze()
43 changes: 26 additions & 17 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,42 +1694,51 @@ def test_array_as_shape():

@pytest.mark.parametrize(
"arr",
[np.array([[0, 3, 0], [1, 2, 0]]), np.array([[[0, 0], [1, 0]], [[5, 0], [0, 3]]])],
[np.array([[0, 3, 0], [1, 2, 0]]), np.array([[[0, 0], [1, 0]], [[5, 0], [0, -3]]])],
)
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize(
"mode",
[(sparse.argmax, np.argmax, lambda x: x), (sparse.argmin, np.argmin, lambda x: -x)],
"mode", [(sparse.argmax, np.argmax), (sparse.argmin, np.argmin)]
)
def test_argmax_argmin(arr, axis, keepdims, mode):
sparse_func, np_func, transform = mode
arr = transform(arr)
sparse_func, np_func = mode

s_arr = sparse.COO.from_numpy(arr)

result = sparse_func(s_arr, axis=axis, keepdims=keepdims)
result = sparse_func(s_arr, axis=axis, keepdims=keepdims).todense()
expected = np_func(arr, axis=axis, keepdims=keepdims)

np.testing.assert_equal(result, expected)


@pytest.mark.parametrize("func", [np.argmax, np.argmin])
def test_argmax_argmin_value_constraint(func):
s = sparse.COO.from_numpy(np.full((2, 2), 2), fill_value=2)
@pytest.mark.parametrize("axis", [None, 0, 1, 2])
@pytest.mark.parametrize(
"mode", [(sparse.argmax, np.argmax), (sparse.argmin, np.argmin)]
)
def test_argmax_argmin_3D(axis, mode):
sparse_func, np_func = mode

with pytest.raises(
ValueError, match="Only 0.0 fill value is supported, but found: 2."
):
func(s)
s_arr = sparse.zeros(shape=(1000, 550, 3), format="dok")
s_arr[100, 100, 0] = 3
s_arr[100, 100, 1] = 3
s_arr[100, 99, 0] = -2
s_arr = s_arr.to_coo()

arr = np.array([[-2, 0], [0, 2]])
s = sparse.COO.from_numpy(arr)
result = sparse_func(s_arr, axis=axis).todense()
expected = np_func(s_arr.todense(), axis=axis)

np.testing.assert_equal(result, expected)


@pytest.mark.parametrize("func", [sparse.argmax, sparse.argmin])
def test_argmax_argmin_constraint(func):
s = sparse.COO.from_numpy(np.full((2, 2), 2), fill_value=2)

with pytest.raises(
ValueError, match=r"None of the non-zero values can be (lt|gt) the fill value"
ValueError, match="axis 2 is out of bounds for array of dimension 2"
):
func(s, axis=0)
func(s, axis=2)


@pytest.mark.parametrize("config", [(np.inf, "isinf"), (np.nan, "isnan")])
Expand Down
Loading