From 0422a7c8a20ef4f51d4fe715ad7184f5e396ab6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Fri, 22 Dec 2023 11:18:04 +0100 Subject: [PATCH] API: Refactor argmax and argmin --- sparse/_coo/common.py | 193 ++++++++++++++++++++++++++------------- sparse/tests/test_coo.py | 43 +++++---- 2 files changed, 153 insertions(+), 83 deletions(-) diff --git a/sparse/_coo/common.py b/sparse/_coo/common.py index 62500c26..c12b1491 100644 --- a/sparse/_coo/common.py +++ b/sparse/_coo/common.py @@ -2,7 +2,7 @@ import operator import warnings from collections.abc import Iterable -from typing import Callable, Optional +from typing import Callable, Optional, Tuple import numpy as np import scipy.sparse @@ -613,65 +613,6 @@ def argwhere(a): return np.transpose(a.nonzero()) -def _arg_minmax_common( - x: SparseArray, - axis: Optional[int], - keepdims: bool, - comp_op: Callable, - np_arg_func: Callable, -): - """ """ - assert comp_op in (operator.lt, operator.gt) - assert np_arg_func in (np.argmax, np.argmin) - - if not isinstance(axis, (int, type(None))): - raise ValueError(f"axis must be int or None, but it's: {type(axis)}") - - if isinstance(axis, int) and axis >= x.ndim: - raise ValueError( - f"axis {axis} is out of bounds for array of dimension {x.ndim}" - ) - - if x.fill_value != 0.0: - raise ValueError( - f"Only 0.0 fill value is supported, but found: {x.fill_value}." - ) - - if np.any(comp_op(x.data, 0.0)): - raise ValueError( - f"None of the non-zero values can be {comp_op.__name__} the fill value." - ) - - # fast path - if axis is None or x.ndim == 1: - x_flat = x.reshape(-1) - result = x_flat.coords[0, np_arg_func(x_flat.data)] - return np.array(result).reshape([1] * x.ndim) if keepdims else result - - # search for min/max value & index for each retained axis - minmax_indexes = {} - minmax_values = {} - - for idx, coord in enumerate(x.coords.T): - coord = list(coord) - axis_index = coord[axis] - coord[axis] = 0 - coord = tuple(coord) - if not coord in minmax_values or comp_op(minmax_values[coord], x.data[idx]): - minmax_values[coord] = x.data[idx] - minmax_indexes[coord] = axis_index - - new_shape = list(x.shape) - new_shape[axis] = 1 - new_shape = tuple(new_shape) - - result = np.zeros(shape=new_shape, dtype=np.intp) - for idx, minmax_index in minmax_indexes.items(): - result[idx] = minmax_index - - return result if keepdims else result.squeeze() - - def argmax(x, /, *, axis=None, keepdims=False): """ Returns the indices of the maximum values along a specified axis. @@ -697,9 +638,7 @@ def argmax(x, /, *, axis=None, keepdims=False): the first occurrence of the maximum value. Otherwise, a non-zero-dimensional array containing the indices of the maximum values. """ - return _arg_minmax_common( - x, axis=axis, keepdims=keepdims, comp_op=operator.lt, np_arg_func=np.argmax - ) + return _arg_minmax_common(x, axis=axis, keepdims=keepdims, mode="max") def argmin(x, /, *, axis=None, keepdims=False): @@ -727,9 +666,7 @@ def argmin(x, /, *, axis=None, keepdims=False): the first occurrence of the minimum value. Otherwise, a non-zero-dimensional array containing the indices of the minimum values. """ - return _arg_minmax_common( - x, axis=axis, keepdims=keepdims, comp_op=operator.gt, np_arg_func=np.argmin - ) + return _arg_minmax_common(x, axis=axis, keepdims=keepdims, mode="min") def _replace_nan(array, value): @@ -1144,3 +1081,127 @@ def clip(a, a_min=None, a_max=None, out=None): """ a = asCOO(a, name="clip") return a.clip(a_min, a_max) + + +@numba.jit(nopython=True, nogil=True) +def _compute_minmax_args( + coords: np.ndarray, + data: np.ndarray, + reduce_size: int, + fill_value: float, + max_mode_flag: bool, +) -> Tuple[np.ndarray, np.ndarray]: + assert coords.shape[0] == 2 + reduce_coords = coords[0, :] + index_coords = coords[1, :] + + result_indices = np.unique(index_coords) + result_data = [] + + # we iterate through each trace + for idx in np.nditer(result_indices): + mask = index_coords == idx + masked_reduce_coords = reduce_coords[mask] + masked_data = data[mask] + + if max_mode_flag: + compared_data = operator.gt(masked_data, fill_value) + else: + compared_data = operator.lt(masked_data, fill_value) + + if np.any(compared_data) or len(masked_data) == reduce_size: + # best value is a non-fill value + best_arg = ( + np.argmax(masked_data) if max_mode_flag else np.argmin(masked_data) + ) + result_data.append(masked_reduce_coords[best_arg]) + else: + # best value is a fill value, find the first occurrence of it + current_coord = np.array(-1, dtype=coords.dtype) + found = False + for idx, new_coord in enumerate(np.nditer(np.sort(masked_reduce_coords))): + # there is at least one fill value between consecutive non-fill values + if new_coord - current_coord > 1: + result_data.append(idx) + found = True + break + current_coord = new_coord + # get the first fill value after all non-fill values + if not found: + result_data.append(current_coord + 1) + + return (result_indices, result_data) + + +def _arg_minmax_common( + x: SparseArray, + axis: Optional[int], + keepdims: bool, + mode: str, +): + """ + Internal implementation for argmax and argmin functions. + """ + assert mode in ("max", "min") + max_mode_flag = mode == "max" + + if not isinstance(axis, (int, type(None))): + raise ValueError(f"axis must be int or None, but it's: {type(axis)}") + if isinstance(axis, int) and axis >= x.ndim: + raise ValueError( + f"axis {axis} is out of bounds for array of dimension {x.ndim}" + ) + if x.ndim == 0: + raise ValueError("Input array must be at least 1-D, but it's 0-D.") + + # If `axis` is None then we need to flatten the input array and memorize + # the original dimensionality for the final reshape operation. + axis_none_original_ndim: Optional[int] = None + if axis is None: + axis_none_original_ndim = x.ndim + x = x.reshape(-1)[:, None] + axis = 0 + + # A 1-D array must have one more singleton dimension. + if axis == 0 and x.ndim == 1: + x = x[:, None] + + # We need to move `axis` to the front. + new_transpose = list(range(x.ndim)) + new_transpose.insert(0, new_transpose.pop(axis)) + new_transpose = tuple(new_transpose) + + # And reshape it to 2-D (reduce axis, the rest of axes flattened) + new_shape = list(x.shape) + new_shape.insert(0, new_shape.pop(axis)) + new_shape = tuple(new_shape) + + x = x.transpose(new_transpose) + x = x.reshape((new_shape[0], np.prod(new_shape[1:]))) + + # Compute max/min arguments + result_indices, result_data = _compute_minmax_args( + x.coords.copy(), + x.data.copy(), + reduce_size=x.shape[0], + fill_value=x.fill_value, + max_mode_flag=max_mode_flag, + ) + + from .core import COO + + result = COO( + result_indices, result_data, shape=(x.shape[1],), fill_value=0, prune=True + ) + + # Let's reshape the result to the original shape. + result = result.reshape((1, *new_shape[1:])) + new_transpose = list(range(result.ndim)) + new_transpose.insert(axis, new_transpose.pop(0)) + result = result.transpose(new_transpose) + + # If `axis=None` we need to reshape flattened array into original dimensionality. + if not axis_none_original_ndim is None: + result = result.reshape([1 for _ in range(axis_none_original_ndim)]) + + return result if keepdims else result.squeeze() diff --git a/sparse/tests/test_coo.py b/sparse/tests/test_coo.py index ebe8e10e..b64d8a06 100644 --- a/sparse/tests/test_coo.py +++ b/sparse/tests/test_coo.py @@ -1694,42 +1694,51 @@ def test_array_as_shape(): @pytest.mark.parametrize( "arr", - [np.array([[0, 3, 0], [1, 2, 0]]), np.array([[[0, 0], [1, 0]], [[5, 0], [0, 3]]])], + [np.array([[0, 3, 0], [1, 2, 0]]), np.array([[[0, 0], [1, 0]], [[5, 0], [0, -3]]])], ) @pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("keepdims", [True, False]) @pytest.mark.parametrize( - "mode", - [(sparse.argmax, np.argmax, lambda x: x), (sparse.argmin, np.argmin, lambda x: -x)], + "mode", [(sparse.argmax, np.argmax), (sparse.argmin, np.argmin)] ) def test_argmax_argmin(arr, axis, keepdims, mode): - sparse_func, np_func, transform = mode - arr = transform(arr) + sparse_func, np_func = mode s_arr = sparse.COO.from_numpy(arr) - result = sparse_func(s_arr, axis=axis, keepdims=keepdims) + result = sparse_func(s_arr, axis=axis, keepdims=keepdims).todense() expected = np_func(arr, axis=axis, keepdims=keepdims) np.testing.assert_equal(result, expected) -@pytest.mark.parametrize("func", [np.argmax, np.argmin]) -def test_argmax_argmin_value_constraint(func): - s = sparse.COO.from_numpy(np.full((2, 2), 2), fill_value=2) +@pytest.mark.parametrize("axis", [None, 0, 1, 2]) +@pytest.mark.parametrize( + "mode", [(sparse.argmax, np.argmax), (sparse.argmin, np.argmin)] +) +def test_argmax_argmin_3D(axis, mode): + sparse_func, np_func = mode - with pytest.raises( - ValueError, match="Only 0.0 fill value is supported, but found: 2." - ): - func(s) + s_arr = sparse.zeros(shape=(1000, 550, 3), format="dok") + s_arr[100, 100, 0] = 3 + s_arr[100, 100, 1] = 3 + s_arr[100, 99, 0] = -2 + s_arr = s_arr.to_coo() - arr = np.array([[-2, 0], [0, 2]]) - s = sparse.COO.from_numpy(arr) + result = sparse_func(s_arr, axis=axis).todense() + expected = np_func(s_arr.todense(), axis=axis) + + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize("func", [sparse.argmax, sparse.argmin]) +def test_argmax_argmin_constraint(func): + s = sparse.COO.from_numpy(np.full((2, 2), 2), fill_value=2) with pytest.raises( - ValueError, match=r"None of the non-zero values can be (lt|gt) the fill value" + ValueError, match="axis 2 is out of bounds for array of dimension 2" ): - func(s, axis=0) + func(s, axis=2) @pytest.mark.parametrize("config", [(np.inf, "isinf"), (np.nan, "isnan")])