diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py
index c085437f..9dbb2417 100644
--- a/cubed/array_api/manipulation_functions.py
+++ b/cubed/array_api/manipulation_functions.py
@@ -151,10 +151,16 @@ def expand_dims(x, /, *, axis):
     chunks = tuple(1 if i in axis else next(chunks_it) for i in range(ndim_new))
 
     return map_blocks(
-        nxp.expand_dims, x, dtype=x.dtype, chunks=chunks, new_axis=axis, axis=axis
+        _expand_dims, x, dtype=x.dtype, chunks=chunks, new_axis=axis, axis=axis
     )
 
 
+def _expand_dims(a, *args, **kwargs):
+    if isinstance(a, dict):
+        return {k: nxp.expand_dims(v, *args, **kwargs) for k, v in a.items()}
+    return nxp.expand_dims(a, *args, **kwargs)
+
+
 def flatten(x):
     return reshape(x, (-1,))
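Note: the `_expand_dims` wrapper is needed because groupby intermediates flow between stages as dicts of arrays (see `_mean_groupby_func` in the tests below), so `expand_dims` must be applied field-wise. A minimal standalone sketch of the behaviour, substituting plain NumPy for the backend namespace `nxp` (an assumption for illustration only):

import numpy as np

def _expand_dims(a, *args, **kwargs):
    # dict intermediates: expand every field; plain arrays pass straight through
    if isinstance(a, dict):
        return {k: np.expand_dims(v, *args, **kwargs) for k, v in a.items()}
    return np.expand_dims(a, *args, **kwargs)

_expand_dims(np.ones((2, 3)), axis=0).shape          # (1, 2, 3)
_expand_dims({"n": np.ones(4)}, axis=-1)["n"].shape  # (4, 1)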
+ """ + + if isinstance(axis, tuple): + if len(axis) != 1: + raise ValueError( + f"Only a single axis is supported for groupby_reduction: {axis}" + ) + axis = axis[0] + + # make sure 'by' has corresponding blocks to 'x' + for ax in range(x.ndim): + if ax != axis: + by = expand_dims(by, axis=ax) + by_chunks = tuple( + c if i == axis else (1,) * x.numblocks[i] for i, c in enumerate(by.chunks) + ) + by_shape = tuple(map(sum, by_chunks)) + by = broadcast_to(by, by_shape, chunks=by_chunks) + + # wrapper to squeeze 'by' to undo effect of broadcast, so it looks same + # to user supplied func + def _group_reduction_func_wrapper(func): + def wrapper(a, by, **kwargs): + return func(a, nxp.squeeze(by), **kwargs) + + return wrapper + + # initial map does group reduction on each block + chunks = tuple( + (num_groups,) * len(c) if i == axis else c for i, c in enumerate(x.chunks) + ) + out = map_blocks( + _group_reduction_func_wrapper(func), + x, + by, + dtype=intermediate_dtype, + chunks=chunks, + axis=axis, + intermediate_dtype=intermediate_dtype, + num_groups=num_groups, + ) + + # add a dummy dimension to reduce over + dummy_axis = -1 + out = expand_dims(out, axis=dummy_axis) + + # then reduce across blocks + return reduction_new( + out, + func=None, + combine_func=combine_func, + aggregate_func=aggregate_func, + axis=(dummy_axis, axis), # dummy and group axis + intermediate_dtype=intermediate_dtype, + dtype=dtype, + keepdims=keepdims, + split_every=split_every, + combine_sizes={axis: num_groups}, # group axis doesn't have size 1 + extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis), + ) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index a8fa6acb..240998a2 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -909,14 +909,14 @@ def reduction( return reduction_new( x, func, - combine_func, - aggregate_func, - axis, - intermediate_dtype, - dtype, - keepdims, - split_every, - extra_func_kwargs, + combine_func=combine_func, + aggregate_func=aggregate_func, + axis=axis, + intermediate_dtype=intermediate_dtype, + dtype=dtype, + keepdims=keepdims, + split_every=split_every, + extra_func_kwargs=extra_func_kwargs, ) if combine_func is None: combine_func = func @@ -1017,10 +1017,49 @@ def reduction_new( dtype=None, keepdims=False, split_every=None, + combine_sizes=None, extra_func_kwargs=None, ) -> "Array": - """Apply a function to reduce an array along one or more axes.""" + """Apply a function to reduce an array along one or more axes. + + Parameters + ---------- + x: Array + Array being reduced along one or more axes. + func: callable + Function to apply to each chunk of data before reduction. + combine_func: callable, optional + Function which may be applied recursively to intermediate chunks of + data. The number of chunks that are combined in each round is + determined by the ``split_every`` parameter. The output of the + function is a chunk with size one (or the size specified in + ``combine_sizes``) in each of the reduction axes. If omitted, + it defaults to ``func``. + aggregate_func: callable, optional + Function to apply to each of the final chunks to produce the final output. + axis: int or sequence of ints, optional + Axis or axes to aggregate upon. If omitted, aggregate along all axes. + intermediate_dtype: dtype + Data type of intermediate output. + dtype: dtype + Data type of output. + keepdims: boolean, optional + Whether the reduction function should preserve the reduced axes, + or remove them. 
diff --git a/cubed/core/ops.py b/cubed/core/ops.py
index a8fa6acb..240998a2 100644
--- a/cubed/core/ops.py
+++ b/cubed/core/ops.py
@@ -909,14 +909,14 @@ def reduction(
         return reduction_new(
             x,
             func,
-            combine_func,
-            aggregate_func,
-            axis,
-            intermediate_dtype,
-            dtype,
-            keepdims,
-            split_every,
-            extra_func_kwargs,
+            combine_func=combine_func,
+            aggregate_func=aggregate_func,
+            axis=axis,
+            intermediate_dtype=intermediate_dtype,
+            dtype=dtype,
+            keepdims=keepdims,
+            split_every=split_every,
+            extra_func_kwargs=extra_func_kwargs,
         )
     if combine_func is None:
         combine_func = func
@@ -1017,10 +1017,49 @@ def reduction_new(
     dtype=None,
     keepdims=False,
     split_every=None,
+    combine_sizes=None,
     extra_func_kwargs=None,
 ) -> "Array":
-    """Apply a function to reduce an array along one or more axes."""
+    """Apply a function to reduce an array along one or more axes.
+
+    Parameters
+    ----------
+    x: Array
+        Array being reduced along one or more axes.
+    func: callable
+        Function to apply to each chunk of data before reduction.
+    combine_func: callable, optional
+        Function which may be applied recursively to intermediate chunks of
+        data. The number of chunks that are combined in each round is
+        determined by the ``split_every`` parameter. The output of the
+        function is a chunk with size one (or the size specified in
+        ``combine_sizes``) in each of the reduction axes. If omitted,
+        it defaults to ``func``.
+    aggregate_func: callable, optional
+        Function to apply to each of the final chunks to produce the final output.
+    axis: int or sequence of ints, optional
+        Axis or axes to aggregate upon. If omitted, aggregate along all axes.
+    intermediate_dtype: dtype
+        Data type of intermediate output.
+    dtype: dtype
+        Data type of output.
+    keepdims: boolean, optional
+        Whether the reduction function should preserve the reduced axes,
+        or remove them.
+    split_every: int >= 2 or dict(axis: int), optional
+        The number of chunks to combine in one round along each axis in the
+        recursive aggregation.
+    combine_sizes: dict(axis: int), optional
+        The resulting size of each axis after reduction. Each reduction axis
+        defaults to size one if not specified.
+    extra_func_kwargs: dict, optional
+        Extra keyword arguments to pass to ``func`` and ``combine_func``.
+    """
     if combine_func is None:
+        if func is None:
+            raise ValueError(
+                "At least one of `func` and `combine_func` must be specified in reduction"
+            )
         combine_func = func
     if axis is None:
         axis = tuple(range(x.ndim))
@@ -1032,14 +1071,19 @@
 
     split_every = _normalize_split_every(split_every, axis)
 
+    if func is None:
+        initial_func = None
+    else:
+        initial_func = partial(
+            func, axis=axis, keepdims=True, **(extra_func_kwargs or {})
+        )
     result = partial_reduce(
         x,
         partial(combine_func, **(extra_func_kwargs or {})),
-        initial_func=partial(
-            func, axis=axis, keepdims=True, **(extra_func_kwargs or {})
-        ),
+        initial_func=initial_func,
         split_every=split_every,
         dtype=intermediate_dtype,
+        combine_sizes=combine_sizes,
     )
 
     # combine intermediates
@@ -1049,6 +1093,7 @@
         axis=axis,
         dtype=intermediate_dtype,
         split_every=split_every,
+        combine_sizes=combine_sizes,
     )
 
     # aggregate final chunks
@@ -1085,6 +1130,7 @@
     axis,
     dtype,
     split_every=None,
+    combine_sizes=None,
 ):
     """Apply a reduction function repeatedly across multiple axes."""
     if axis is None:
@@ -1105,11 +1151,19 @@
             func,
             split_every=split_every,
             dtype=dtype,
+            combine_sizes=combine_sizes,
         )
     return x
 
 
-def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None):
+def partial_reduce(
+    x,
+    func,
+    initial_func=None,
+    split_every=None,
+    dtype=None,
+    combine_sizes=None,
+):
     """Apply a reduction function to multiple blocks across multiple axes.
 
     Parameters
@@ -1118,21 +1172,30 @@
         Array being reduced along one or more axes
     func: callable
         Reduction function to apply to each chunk of data, resulting in a chunk
-        with size one in each of the reduction axes.
+        with size one (or the size specified in ``combine_sizes``) in each of
+        the reduction axes.
    initial_func: callable, optional
         Function to apply to each chunk of data before reduction.
     split_every: int >= 2 or dict(axis: int), optional
-        The depth of the recursive aggregation.
-    dtype: DType
+        The number of chunks to combine in one round along each axis in the
+        recursive aggregation.
+    dtype: dtype
         Output data type.
+    combine_sizes: dict(axis: int), optional
+        The resulting size of each axis after reduction. Each reduction axis
+        defaults to size one if not specified.
     """
     # map over output chunks
+    axis = tuple(ax for ax in split_every.keys())
+    combine_sizes = combine_sizes or {}
+    combine_sizes = {k: combine_sizes.get(k, 1) for k in axis}
     chunks = [
-        (1,) * math.ceil(len(c) / split_every[i]) if i in split_every else c
+        (combine_sizes[i],) * math.ceil(len(c) / split_every[i])
+        if i in split_every
+        else c
         for (i, c) in enumerate(x.chunks)
     ]
     shape = tuple(map(sum, chunks))
-    axis = tuple(ax for ax in split_every.keys())
 
     def key_function(out_key):
         out_coords = out_key[1:]
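The effect of `combine_sizes` is easiest to see in the output-chunks computation in `partial_reduce`. A worked example of that expression with hypothetical values: when `num_groups=2` groups are kept along axis 0 and six input blocks are combined four at a time, the reduced axis collapses to two output blocks of size 2 rather than size 1:

import math

x_chunks = ((4, 4, 4, 4, 4, 4), (2, 2, 1))  # hypothetical input chunk sizes
split_every = {0: 4}                         # combine 4 blocks per round on axis 0
combine_sizes = {0: 2}                       # reduced axis keeps size num_groups=2

axis = tuple(split_every.keys())
combine_sizes = {k: combine_sizes.get(k, 1) for k in axis}
chunks = [
    (combine_sizes[i],) * math.ceil(len(c) / split_every[i])
    if i in split_every
    else c
    for (i, c) in enumerate(x_chunks)
]
assert chunks == [(2, 2), (2, 2, 1)]  # was [(1, 1), (2, 2, 1)] before this change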
""" # map over output chunks + axis = tuple(ax for ax in split_every.keys()) + combine_sizes = combine_sizes or {} + combine_sizes = {k: combine_sizes.get(k, 1) for k in axis} chunks = [ - (1,) * math.ceil(len(c) / split_every[i]) if i in split_every else c + (combine_sizes[i],) * math.ceil(len(c) / split_every[i]) + if i in split_every + else c for (i, c) in enumerate(x.chunks) ] shape = tuple(map(sum, chunks)) - axis = tuple(ax for ax in split_every.keys()) def key_function(out_key): out_coords = out_key[1:] diff --git a/cubed/tests/test_groupby.py b/cubed/tests/test_groupby.py new file mode 100644 index 00000000..16e2b1a4 --- /dev/null +++ b/cubed/tests/test_groupby.py @@ -0,0 +1,61 @@ +import numpy as np +import numpy_groupies as npg +from numpy.testing import assert_array_equal + +import cubed.array_api as xp +from cubed.backend_array_api import namespace as nxp +from cubed.core.groupby import groupby_reduction + + +def test_groupby_reduction_axis0(): + a = xp.full((4 * 6, 5), 7, dtype=nxp.int32, chunks=(4, 2)) + b = xp.asarray([0, 1, 0, 1] * 6, chunks=(4,)) + c = mean_groupby_reduction(a, b, axis=0, num_groups=2) + assert_array_equal(c.compute(), np.full((2, 5), 7)) + + +def test_groupby_reduction_axis1(): + a = xp.full((5, 4 * 6), 7, dtype=nxp.int32, chunks=(2, 4)) + b = xp.asarray([0, 1, 0, 1] * 6, chunks=(4,)) + c = mean_groupby_reduction(a, b, axis=1, num_groups=2) + assert_array_equal(c.compute(), np.full((5, 2), 7)) + + +def mean_groupby_reduction(x, by, axis, num_groups): + intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] + dtype = x.dtype + + return groupby_reduction( + x, + by, + func=_mean_groupby_func, + combine_func=_mean_groupby_combine, + aggregate_func=_mean_groupby_aggregate, + axis=axis, + intermediate_dtype=intermediate_dtype, + dtype=dtype, + num_groups=num_groups, + ) + + +def _mean_groupby_func(a, by, axis, intermediate_dtype, num_groups): + dtype = dict(intermediate_dtype) + n = npg.aggregate(by, a, func="len", dtype=dtype["n"], axis=axis, size=num_groups) + total = npg.aggregate( + by, a, func="sum", dtype=dtype["total"], axis=axis, size=num_groups + ) + return {"n": n, "total": total} + + +def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims): + # only combine over the dummy axis, to preserve grouping along 'axis' + dtype = dict(dtype) + n = nxp.sum(a["n"], dtype=dtype["n"], axis=dummy_axis, keepdims=keepdims) + total = nxp.sum( + a["total"], dtype=dtype["total"], axis=dummy_axis, keepdims=keepdims + ) + return {"n": n, "total": total} + + +def _mean_groupby_aggregate(a): + return nxp.divide(a["total"], a["n"]) diff --git a/pyproject.toml b/pyproject.toml index 5bd08244..5c639704 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ test = [ "apache-beam", # beam but not gcsfs as tests use local beam runner "cubed[diagnostics,lithops]", # modal tests separate due to conflicting package reqs "dill", + "numpy_groupies", "pytest", "pytest-cov", "pytest-mock", @@ -76,6 +77,7 @@ test = [ test-dask = [ "cubed[dask,diagnostics]", "dill", + "numpy_groupies", "pytest", "pytest-cov", "pytest-mock", @@ -83,12 +85,14 @@ test-dask = [ test-dask-distributed = [ "cubed[dask-distributed,diagnostics]", "dill", + "numpy_groupies", "pytest", "pytest-cov", "pytest-mock", ] test-modal = [ "cubed[modal]", + "numpy_groupies", "dill", "pytest", "pytest-cov",