From 22f174d31b5d8a556f7db3770795db0a6fc00b08 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 9 Apr 2024 09:58:08 +0100 Subject: [PATCH] Group by --- cubed/array_api/manipulation_functions.py | 8 +- cubed/core/groupby.py | 89 +++++++++++++++++++++++ cubed/core/ops.py | 38 +++++++--- cubed/tests/test_groupby.py | 61 ++++++++++++++++ pyproject.toml | 1 + 5 files changed, 185 insertions(+), 12 deletions(-) create mode 100644 cubed/core/groupby.py create mode 100644 cubed/tests/test_groupby.py diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index c085437f0..9dbb2417a 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -151,10 +151,16 @@ def expand_dims(x, /, *, axis): chunks = tuple(1 if i in axis else next(chunks_it) for i in range(ndim_new)) return map_blocks( - nxp.expand_dims, x, dtype=x.dtype, chunks=chunks, new_axis=axis, axis=axis + _expand_dims, x, dtype=x.dtype, chunks=chunks, new_axis=axis, axis=axis ) +def _expand_dims(a, *args, **kwargs): + if isinstance(a, dict): + return {k: nxp.expand_dims(v, *args, **kwargs) for k, v in a.items()} + return nxp.expand_dims(a, *args, **kwargs) + + def flatten(x): return reshape(x, (-1,)) diff --git a/cubed/core/groupby.py b/cubed/core/groupby.py new file mode 100644 index 000000000..98696f56d --- /dev/null +++ b/cubed/core/groupby.py @@ -0,0 +1,89 @@ +from typing import TYPE_CHECKING + +from cubed.array_api.manipulation_functions import broadcast_to, expand_dims +from cubed.backend_array_api import namespace as nxp +from cubed.core.ops import map_blocks, reduction_new + +if TYPE_CHECKING: + from cubed.array_api.array_object import Array + + +def groupby_reduction( + x: "Array", + by: "Array", + func, + combine_func=None, + aggegrate_func=None, + axis=None, + intermediate_dtype=None, + dtype=None, + keepdims=False, + split_every=None, + num_groups=None, + extra_func_kwargs=None, +) -> "Array": + """A reduction that performs groupby aggregations.""" + + if isinstance(axis, tuple): + if len(axis) != 1: + raise ValueError( + f"Only a single axis is supported for groupby_reduction: {axis}" + ) + axis = axis[0] + + # make sure 'by' has corresponding blocks to 'x' + for ax in range(x.ndim): + if ax != axis: + by = expand_dims(by, axis=ax) + by_chunks = tuple( + c if i == axis else (1,) * x.numblocks[i] for i, c in enumerate(by.chunks) + ) + by_shape = tuple(map(sum, by_chunks)) + by = broadcast_to(by, by_shape, chunks=by_chunks) + + # wrapper to squeeze 'by' to undo effect of broadcast, so it looks same + # to user supplied func + def _group_reduction_func_wrapper(func): + def wrapper(a, by, **kwargs): + return func(a, nxp.squeeze(by), **kwargs) + + return wrapper + + # initial map does group reduction on each block + chunks = tuple( + (num_groups,) * len(c) if i == axis else c for i, c in enumerate(x.chunks) + ) + out = map_blocks( + _group_reduction_func_wrapper(func), + x, + by, + dtype=intermediate_dtype, + chunks=chunks, + axis=axis, + intermediate_dtype=intermediate_dtype, + num_groups=num_groups, + ) + + # add a dummy dimension to reduce over + dummy_axis = -1 + out = expand_dims(out, axis=dummy_axis) + + # then reduce across blocks + return reduction_new( + out, + func=_identity_func, + combine_func=combine_func, + aggegrate_func=aggegrate_func, + axis=(dummy_axis, axis), # dummy and group axis + intermediate_dtype=intermediate_dtype, + dtype=dtype, + keepdims=keepdims, + split_every=split_every, + combine_sizes={axis: num_groups}, # group axis doesn't have size 1 + extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis), + ) + + +def _identity_func(a, **kwargs): + # pass through + return a diff --git a/cubed/core/ops.py b/cubed/core/ops.py index deba169c5..0d10f66ea 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -900,14 +900,14 @@ def reduction( return reduction_new( x, func, - combine_func, - aggegrate_func, - axis, - intermediate_dtype, - dtype, - keepdims, - split_every, - extra_func_kwargs, + combine_func=combine_func, + aggegrate_func=aggegrate_func, + axis=axis, + intermediate_dtype=intermediate_dtype, + dtype=dtype, + keepdims=keepdims, + split_every=split_every, + extra_func_kwargs=extra_func_kwargs, ) if combine_func is None: combine_func = func @@ -1008,6 +1008,7 @@ def reduction_new( dtype=None, keepdims=False, split_every=None, + combine_sizes=None, extra_func_kwargs=None, ) -> "Array": """Apply a function to reduce an array along one or more axes.""" @@ -1031,6 +1032,7 @@ def reduction_new( ), split_every=split_every, dtype=intermediate_dtype, + combine_sizes=combine_sizes, ) # combine intermediates @@ -1040,6 +1042,7 @@ def reduction_new( axis=axis, dtype=intermediate_dtype, split_every=split_every, + combine_sizes=combine_sizes, ) # aggregate final chunks @@ -1076,6 +1079,7 @@ def tree_reduce( axis, dtype, split_every=None, + combine_sizes=None, ): """Apply a reduction function repeatedly across multiple axes.""" if axis is None: @@ -1096,11 +1100,19 @@ def tree_reduce( func, split_every=split_every, dtype=dtype, + combine_sizes=combine_sizes, ) return x -def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None): +def partial_reduce( + x, + func, + initial_func=None, + split_every=None, + dtype=None, + combine_sizes=None, +): """Apply a reduction function to multiple blocks across multiple axes. Parameters @@ -1118,12 +1130,16 @@ def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None): Output data type. """ # map over output chunks + axis = tuple(ax for ax in split_every.keys()) + combine_sizes = combine_sizes or {} + combine_sizes = {k: combine_sizes.get(k, 1) for k in axis} chunks = [ - (1,) * math.ceil(len(c) / split_every[i]) if i in split_every else c + (combine_sizes[i],) * math.ceil(len(c) / split_every[i]) + if i in split_every + else c for (i, c) in enumerate(x.chunks) ] shape = tuple(map(sum, chunks)) - axis = tuple(ax for ax in split_every.keys()) def key_function(out_key): out_coords = out_key[1:] diff --git a/cubed/tests/test_groupby.py b/cubed/tests/test_groupby.py new file mode 100644 index 000000000..c7a855c13 --- /dev/null +++ b/cubed/tests/test_groupby.py @@ -0,0 +1,61 @@ +import numpy as np +import numpy_groupies as npg +from numpy.testing import assert_array_equal + +import cubed.array_api as xp +from cubed.backend_array_api import namespace as nxp +from cubed.core.groupby import groupby_reduction + + +def test_groupby_reduction_axis0(): + a = xp.full((4 * 6, 5), 7, dtype=nxp.int32, chunks=(4, 2)) + b = xp.asarray([0, 1, 0, 1] * 6, chunks=(4,)) + c = mean_groupby_reduction(a, b, axis=0, num_groups=2) + assert_array_equal(c.compute(), np.full((2, 5), 7)) + + +def test_groupby_reduction_axis1(): + a = xp.full((5, 4 * 6), 7, dtype=nxp.int32, chunks=(2, 4)) + b = xp.asarray([0, 1, 0, 1] * 6, chunks=(4,)) + c = mean_groupby_reduction(a, b, axis=1, num_groups=2) + assert_array_equal(c.compute(), np.full((5, 2), 7)) + + +def mean_groupby_reduction(x, by, axis, num_groups): + intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] + dtype = x.dtype + + return groupby_reduction( + x, + by, + func=_mean_groupby_func, + combine_func=_mean_groupby_combine, + aggegrate_func=_mean_groupby_aggregate, + axis=axis, + intermediate_dtype=intermediate_dtype, + dtype=dtype, + num_groups=num_groups, + ) + + +def _mean_groupby_func(a, by, axis, intermediate_dtype, num_groups): + dtype = dict(intermediate_dtype) # TODO: shouldn't have to do this + n = npg.aggregate(by, a, func="len", dtype=dtype["n"], axis=axis, size=num_groups) + total = npg.aggregate( + by, a, func="sum", dtype=dtype["total"], axis=axis, size=num_groups + ) + return {"n": n, "total": total} + + +def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims): + # only combine over the dummy axis, to preserve grouping along 'axis' + dtype = dict(dtype) # TODO: shouldn't have to do this + n = nxp.sum(a["n"], dtype=dtype["n"], axis=dummy_axis, keepdims=keepdims) + total = nxp.sum( + a["total"], dtype=dtype["total"], axis=dummy_axis, keepdims=keepdims + ) + return {"n": n, "total": total} + + +def _mean_groupby_aggregate(a): + return nxp.divide(a["total"], a["n"]) diff --git a/pyproject.toml b/pyproject.toml index 78457a240..2acff9897 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ test = [ "apache-beam", # beam but not gcsfs as tests use local beam runner "cubed[diagnostics,lithops]", # modal tests separate due to conflicting package reqs "dill", + "numpy_groupies", "pytest", "pytest-cov", "pytest-mock",