Skip to content

Commit

Permalink
API: Refactor argmax and argmin (#614)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol authored Jan 2, 2024
1 parent f522926 commit b12d51d
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 83 deletions.
193 changes: 127 additions & 66 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import operator
import warnings
from collections.abc import Iterable
from typing import Callable, Optional
from typing import Optional, Tuple

import numpy as np
import scipy.sparse
Expand Down Expand Up @@ -613,65 +613,6 @@ def argwhere(a):
return np.transpose(a.nonzero())


def _arg_minmax_common(
x: SparseArray,
axis: Optional[int],
keepdims: bool,
comp_op: Callable,
np_arg_func: Callable,
):
""" """
assert comp_op in (operator.lt, operator.gt)
assert np_arg_func in (np.argmax, np.argmin)

if not isinstance(axis, (int, type(None))):
raise ValueError(f"axis must be int or None, but it's: {type(axis)}")

if isinstance(axis, int) and axis >= x.ndim:
raise ValueError(
f"axis {axis} is out of bounds for array of dimension {x.ndim}"
)

if x.fill_value != 0.0:
raise ValueError(
f"Only 0.0 fill value is supported, but found: {x.fill_value}."
)

if np.any(comp_op(x.data, 0.0)):
raise ValueError(
f"None of the non-zero values can be {comp_op.__name__} the fill value."
)

# fast path
if axis is None or x.ndim == 1:
x_flat = x.reshape(-1)
result = x_flat.coords[0, np_arg_func(x_flat.data)]
return np.array(result).reshape([1] * x.ndim) if keepdims else result

# search for min/max value & index for each retained axis
minmax_indexes = {}
minmax_values = {}

for idx, coord in enumerate(x.coords.T):
coord = list(coord)
axis_index = coord[axis]
coord[axis] = 0
coord = tuple(coord)
if not coord in minmax_values or comp_op(minmax_values[coord], x.data[idx]):
minmax_values[coord] = x.data[idx]
minmax_indexes[coord] = axis_index

new_shape = list(x.shape)
new_shape[axis] = 1
new_shape = tuple(new_shape)

result = np.zeros(shape=new_shape, dtype=np.intp)
for idx, minmax_index in minmax_indexes.items():
result[idx] = minmax_index

return result if keepdims else result.squeeze()


def argmax(x, /, *, axis=None, keepdims=False):
"""
Returns the indices of the maximum values along a specified axis.
Expand All @@ -697,9 +638,7 @@ def argmax(x, /, *, axis=None, keepdims=False):
the first occurrence of the maximum value. Otherwise, a non-zero-dimensional
array containing the indices of the maximum values.
"""
return _arg_minmax_common(
x, axis=axis, keepdims=keepdims, comp_op=operator.lt, np_arg_func=np.argmax
)
return _arg_minmax_common(x, axis=axis, keepdims=keepdims, mode="max")


def argmin(x, /, *, axis=None, keepdims=False):
Expand Down Expand Up @@ -727,9 +666,7 @@ def argmin(x, /, *, axis=None, keepdims=False):
the first occurrence of the minimum value. Otherwise, a non-zero-dimensional
array containing the indices of the minimum values.
"""
return _arg_minmax_common(
x, axis=axis, keepdims=keepdims, comp_op=operator.gt, np_arg_func=np.argmin
)
return _arg_minmax_common(x, axis=axis, keepdims=keepdims, mode="min")


def _replace_nan(array, value):
Expand Down Expand Up @@ -1144,3 +1081,127 @@ def clip(a, a_min=None, a_max=None, out=None):
"""
a = asCOO(a, name="clip")
return a.clip(a_min, a_max)


@numba.jit(nopython=True, nogil=True)
def _compute_minmax_args(
coords: np.ndarray,
data: np.ndarray,
reduce_size: int,
fill_value: float,
max_mode_flag: bool,
) -> Tuple[np.ndarray, np.ndarray]:
assert coords.shape[0] == 2
reduce_coords = coords[0, :]
index_coords = coords[1, :]

result_indices = np.unique(index_coords)
result_data = []

# we iterate through each trace
for result_index in np.nditer(result_indices):
mask = index_coords == result_index
masked_reduce_coords = reduce_coords[mask]
masked_data = data[mask]

if max_mode_flag:
compared_data = operator.gt(masked_data, fill_value)
else:
compared_data = operator.lt(masked_data, fill_value)

if np.any(compared_data) or len(masked_data) == reduce_size:
# best value is a non-fill value
best_arg = (
np.argmax(masked_data) if max_mode_flag else np.argmin(masked_data)
)
result_data.append(masked_reduce_coords[best_arg])
else:
# best value is a fill value, find the first occurrence of it
current_coord = np.array(-1, dtype=coords.dtype)
found = False
for idx, new_coord in enumerate(np.nditer(np.sort(masked_reduce_coords))):
# there is at least one fill value between consecutive non-fill values
if new_coord - current_coord > 1:
result_data.append(idx)
found = True
break
current_coord = new_coord
# get the first fill value after all non-fill values
if not found:
result_data.append(current_coord + 1)

return (result_indices, result_data)


def _arg_minmax_common(
x: SparseArray,
axis: Optional[int],
keepdims: bool,
mode: str,
):
"""
Internal implementation for argmax and argmin functions.
"""
assert mode in ("max", "min")
max_mode_flag = mode == "max"

if not isinstance(axis, (int, type(None))):
raise ValueError(f"axis must be int or None, but it's: {type(axis)}")
if isinstance(axis, int) and axis >= x.ndim:
raise ValueError(
f"axis {axis} is out of bounds for array of dimension {x.ndim}"
)
if x.ndim == 0:
raise ValueError("Input array must be at least 1-D, but it's 0-D.")

# If `axis` is None then we need to flatten the input array and memorize
# the original dimensionality for the final reshape operation.
axis_none_original_ndim: Optional[int] = None
if axis is None:
axis_none_original_ndim = x.ndim
x = x.reshape(-1)[:, None]
axis = 0

# A 1-D array must have one more singleton dimension.
if axis == 0 and x.ndim == 1:
x = x[:, None]

# We need to move `axis` to the front.
new_transpose = list(range(x.ndim))
new_transpose.insert(0, new_transpose.pop(axis))
new_transpose = tuple(new_transpose)

# And reshape it to 2-D (reduce axis, the rest of axes flattened)
new_shape = list(x.shape)
new_shape.insert(0, new_shape.pop(axis))
new_shape = tuple(new_shape)

x = x.transpose(new_transpose)
x = x.reshape((new_shape[0], np.prod(new_shape[1:])))

# Compute max/min arguments
result_indices, result_data = _compute_minmax_args(
x.coords.copy(),
x.data.copy(),
reduce_size=x.shape[0],
fill_value=x.fill_value,
max_mode_flag=max_mode_flag,
)

from .core import COO

result = COO(
result_indices, result_data, shape=(x.shape[1],), fill_value=0, prune=True
)

# Let's reshape the result to the original shape.
result = result.reshape((1, *new_shape[1:]))
new_transpose = list(range(result.ndim))
new_transpose.insert(axis, new_transpose.pop(0))
result = result.transpose(new_transpose)

# If `axis=None` we need to reshape flattened array into original dimensionality.
if not axis_none_original_ndim is None:
result = result.reshape([1 for _ in range(axis_none_original_ndim)])

return result if keepdims else result.squeeze()
43 changes: 26 additions & 17 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,42 +1694,51 @@ def test_array_as_shape():

@pytest.mark.parametrize(
"arr",
[np.array([[0, 3, 0], [1, 2, 0]]), np.array([[[0, 0], [1, 0]], [[5, 0], [0, 3]]])],
[np.array([[0, 3, 0], [1, 2, 0]]), np.array([[[0, 0], [1, 0]], [[5, 0], [0, -3]]])],
)
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize(
"mode",
[(sparse.argmax, np.argmax, lambda x: x), (sparse.argmin, np.argmin, lambda x: -x)],
"mode", [(sparse.argmax, np.argmax), (sparse.argmin, np.argmin)]
)
def test_argmax_argmin(arr, axis, keepdims, mode):
sparse_func, np_func, transform = mode
arr = transform(arr)
sparse_func, np_func = mode

s_arr = sparse.COO.from_numpy(arr)

result = sparse_func(s_arr, axis=axis, keepdims=keepdims)
result = sparse_func(s_arr, axis=axis, keepdims=keepdims).todense()
expected = np_func(arr, axis=axis, keepdims=keepdims)

np.testing.assert_equal(result, expected)


@pytest.mark.parametrize("func", [np.argmax, np.argmin])
def test_argmax_argmin_value_constraint(func):
s = sparse.COO.from_numpy(np.full((2, 2), 2), fill_value=2)
@pytest.mark.parametrize("axis", [None, 0, 1, 2])
@pytest.mark.parametrize(
"mode", [(sparse.argmax, np.argmax), (sparse.argmin, np.argmin)]
)
def test_argmax_argmin_3D(axis, mode):
sparse_func, np_func = mode

with pytest.raises(
ValueError, match="Only 0.0 fill value is supported, but found: 2."
):
func(s)
s_arr = sparse.zeros(shape=(1000, 550, 3), format="dok")
s_arr[100, 100, 0] = 3
s_arr[100, 100, 1] = 3
s_arr[100, 99, 0] = -2
s_arr = s_arr.to_coo()

arr = np.array([[-2, 0], [0, 2]])
s = sparse.COO.from_numpy(arr)
result = sparse_func(s_arr, axis=axis).todense()
expected = np_func(s_arr.todense(), axis=axis)

np.testing.assert_equal(result, expected)


@pytest.mark.parametrize("func", [sparse.argmax, sparse.argmin])
def test_argmax_argmin_constraint(func):
s = sparse.COO.from_numpy(np.full((2, 2), 2), fill_value=2)

with pytest.raises(
ValueError, match=r"None of the non-zero values can be (lt|gt) the fill value"
ValueError, match="axis 2 is out of bounds for array of dimension 2"
):
func(s, axis=0)
func(s, axis=2)


@pytest.mark.parametrize("config", [(np.inf, "isinf"), (np.nan, "isnan")])
Expand Down

0 comments on commit b12d51d

Please sign in to comment.