Group by #442

Merged 5 commits on Apr 17, 2024
8 changes: 7 additions & 1 deletion cubed/array_api/manipulation_functions.py
@@ -151,10 +151,16 @@ def expand_dims(x, /, *, axis):
     chunks = tuple(1 if i in axis else next(chunks_it) for i in range(ndim_new))
 
     return map_blocks(
-        nxp.expand_dims, x, dtype=x.dtype, chunks=chunks, new_axis=axis, axis=axis
+        _expand_dims, x, dtype=x.dtype, chunks=chunks, new_axis=axis, axis=axis
     )
 
 
+def _expand_dims(a, *args, **kwargs):
+    if isinstance(a, dict):
+        return {k: nxp.expand_dims(v, *args, **kwargs) for k, v in a.items()}
+    return nxp.expand_dims(a, *args, **kwargs)
+
+
 def flatten(x):
     return reshape(x, (-1,))

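The new ``_expand_dims`` wrapper is needed because a block passed to ``map_blocks`` may now be a dict of arrays (the structured intermediate that a group reduction produces, as in the tests below) rather than a single array. A minimal sketch of the behavior, using plain NumPy in place of the backend namespace ``nxp``:

    import numpy as np

    def _expand_dims(a, *args, **kwargs):
        # a groupby intermediate block is a dict of arrays, so expand_dims
        # is applied to each value rather than to the dict itself
        if isinstance(a, dict):
            return {k: np.expand_dims(v, *args, **kwargs) for k, v in a.items()}
        return np.expand_dims(a, *args, **kwargs)

    block = {"n": np.ones((2, 3)), "total": np.full((2, 3), 7.0)}
    out = _expand_dims(block, axis=-1)
    assert out["n"].shape == (2, 3, 1) and out["total"].shape == (2, 3, 1)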
118 changes: 118 additions & 0 deletions cubed/core/groupby.py
@@ -0,0 +1,118 @@
from typing import TYPE_CHECKING

from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import map_blocks, reduction_new

if TYPE_CHECKING:
    from cubed.array_api.array_object import Array


def groupby_reduction(
    x: "Array",
    by: "Array",
    func,
    combine_func=None,
    aggregate_func=None,
    axis=None,
    intermediate_dtype=None,
    dtype=None,
    keepdims=False,
    split_every=None,
    num_groups=None,
    extra_func_kwargs=None,
) -> "Array":
    """A reduction that performs groupby aggregations.

    Parameters
    ----------
    x: Array
        Array being grouped along one axis.
    by: Array
        Array of non-negative integers to be used as labels with which to group
        the values in ``x`` along the reduction axis. Must be a 1D array.
    func: callable
        Function to apply to each chunk of data before reduction.
    combine_func: callable
        Function which may be applied recursively to intermediate chunks of
        data. The number of chunks that are combined in each round is
        determined by the ``split_every`` parameter. The output of the
        function is a chunk with size ``num_groups`` along the reduction axis.
    aggregate_func: callable, optional
        Function to apply to each of the final chunks to produce the final output.
    axis: int or sequence of ints, optional
        Axis to aggregate along. Only supports a single axis.
    intermediate_dtype: dtype
        Data type of intermediate output.
    dtype: dtype
        Data type of output.
    keepdims: boolean, optional
        Whether the reduction function should preserve the reduced axes,
        or remove them.
    split_every: int >= 2 or dict(axis: int), optional
        The number of chunks to combine in one round along each axis in the
        recursive aggregation.
    num_groups: int
        The number of groups in the grouping array ``by``.
    extra_func_kwargs: dict, optional
        Extra keyword arguments to pass to ``func`` and ``combine_func``.
    """

    if isinstance(axis, tuple):
        if len(axis) != 1:
            raise ValueError(
                f"Only a single axis is supported for groupby_reduction: {axis}"
            )
        axis = axis[0]

    # make sure 'by' has corresponding blocks to 'x'
    for ax in range(x.ndim):
        if ax != axis:
            by = expand_dims(by, axis=ax)
    by_chunks = tuple(
        c if i == axis else (1,) * x.numblocks[i] for i, c in enumerate(by.chunks)
    )
    by_shape = tuple(map(sum, by_chunks))
    by = broadcast_to(by, by_shape, chunks=by_chunks)

    # wrapper to squeeze 'by' to undo the effect of the broadcast, so it
    # looks the same to the user-supplied func
    def _group_reduction_func_wrapper(func):
        def wrapper(a, by, **kwargs):
            return func(a, nxp.squeeze(by), **kwargs)

        return wrapper

    # initial map does group reduction on each block
    chunks = tuple(
        (num_groups,) * len(c) if i == axis else c for i, c in enumerate(x.chunks)
    )
    out = map_blocks(
        _group_reduction_func_wrapper(func),
        x,
        by,
        dtype=intermediate_dtype,
        chunks=chunks,
        axis=axis,
        intermediate_dtype=intermediate_dtype,
        num_groups=num_groups,
    )

    # add a dummy dimension to reduce over
    dummy_axis = -1
    out = expand_dims(out, axis=dummy_axis)

    # then reduce across blocks
    return reduction_new(
        out,
        func=None,
        combine_func=combine_func,
        aggregate_func=aggregate_func,
        axis=(dummy_axis, axis),  # dummy and group axis
        intermediate_dtype=intermediate_dtype,
        dtype=dtype,
        keepdims=keepdims,
        split_every=split_every,
        combine_sizes={axis: num_groups},  # group axis doesn't have size 1
        extra_func_kwargs=dict(dtype=intermediate_dtype, dummy_axis=dummy_axis),
    )
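To make the broadcasting step above concrete, here is a rough plain-NumPy sketch (hypothetical shapes, not the cubed API) of how ``by`` is expanded and broadcast so that every block of ``x`` gets a matching block of labels:

    import numpy as np

    # x: shape (8, 5), chunks (4, 2) -> numblocks (2, 3), grouped along axis=0
    by = np.array([0, 1, 0, 1, 0, 1, 0, 1])  # 1D labels, shape (8,)

    # expand_dims once for each non-group axis...
    by = np.expand_dims(by, axis=1)  # shape (8, 1)

    # ...then broadcast so there is one copy of the labels per block of x
    # along the non-group axes (by_chunks is ((4, 4), (1, 1, 1)), so
    # by_shape is (8, 3))
    by = np.broadcast_to(by, (8, 3))

    # map_blocks pairs block (i, j) of x with block (i, j) of by; the
    # wrapper then squeezes the length-1 axes away, so the user-supplied
    # func sees the original 1D labels for its rows
    assert np.squeeze(by[:4, 0:1]).shape == (4,)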
99 changes: 81 additions & 18 deletions cubed/core/ops.py
@@ -909,14 +909,14 @@ def reduction(
         return reduction_new(
             x,
             func,
-            combine_func,
-            aggregate_func,
-            axis,
-            intermediate_dtype,
-            dtype,
-            keepdims,
-            split_every,
-            extra_func_kwargs,
+            combine_func=combine_func,
+            aggregate_func=aggregate_func,
+            axis=axis,
+            intermediate_dtype=intermediate_dtype,
+            dtype=dtype,
+            keepdims=keepdims,
+            split_every=split_every,
+            extra_func_kwargs=extra_func_kwargs,
         )
     if combine_func is None:
         combine_func = func
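A note on the change above: this PR inserts the new ``combine_sizes`` parameter before ``extra_func_kwargs`` in ``reduction_new``'s signature, so the old positional forwarding would have mis-bound arguments. A toy sketch (compressed, hypothetical signature) of the hazard the keyword style avoids:

    def reduction_new_sketch(x, func, combine_func=None, aggregate_func=None,
                             axis=None, intermediate_dtype=None, dtype=None,
                             keepdims=False, split_every=None,
                             combine_sizes=None, extra_func_kwargs=None):
        return combine_sizes, extra_func_kwargs

    # forwarding the old positional argument list: the kwargs dict silently
    # lands in combine_sizes, and extra_func_kwargs is left as None
    got = reduction_new_sketch(None, None, None, None, None, None, None,
                               False, None, {"dtype": "int64"})
    assert got == ({"dtype": "int64"}, None)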
@@ -1017,10 +1017,49 @@ def reduction_new(
     dtype=None,
     keepdims=False,
     split_every=None,
+    combine_sizes=None,
     extra_func_kwargs=None,
 ) -> "Array":
-    """Apply a function to reduce an array along one or more axes."""
+    """Apply a function to reduce an array along one or more axes.
+
+    Parameters
+    ----------
+    x: Array
+        Array being reduced along one or more axes.
+    func: callable
+        Function to apply to each chunk of data before reduction.
+    combine_func: callable, optional
+        Function which may be applied recursively to intermediate chunks of
+        data. The number of chunks that are combined in each round is
+        determined by the ``split_every`` parameter. The output of the
+        function is a chunk with size one (or the size specified in
+        ``combine_sizes``) in each of the reduction axes. If omitted,
+        it defaults to ``func``.
+    aggregate_func: callable, optional
+        Function to apply to each of the final chunks to produce the final output.
+    axis: int or sequence of ints, optional
+        Axis or axes to aggregate upon. If omitted, aggregate along all axes.
+    intermediate_dtype: dtype
+        Data type of intermediate output.
+    dtype: dtype
+        Data type of output.
+    keepdims: boolean, optional
+        Whether the reduction function should preserve the reduced axes,
+        or remove them.
+    split_every: int >= 2 or dict(axis: int), optional
+        The number of chunks to combine in one round along each axis in the
+        recursive aggregation.
+    combine_sizes: dict(axis: int), optional
+        The resulting size of each axis after reduction. Each reduction axis
+        defaults to size one if not specified.
+    extra_func_kwargs: dict, optional
+        Extra keyword arguments to pass to ``func`` and ``combine_func``.
+    """
+    if combine_func is None:
+        if func is None:
+            raise ValueError(
+                "At least one of `func` and `combine_func` must be specified in reduction"
+            )
+        combine_func = func
     if axis is None:
         axis = tuple(range(x.ndim))
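The reason for both additions is visible in ``groupby_reduction`` above: it calls ``reduction_new`` with ``func=None`` (the initial per-block work has already been done by ``map_blocks``) and ``combine_sizes={axis: num_groups}`` (the group axis keeps size ``num_groups`` instead of collapsing to one). A small worked sketch of the combine rounds (hypothetical numbers):

    import math

    # 8 chunks along the reduced axis, split_every of 4, 3 groups
    nchunks, split_every, num_groups = 8, 4, 3

    rounds = []
    while nchunks > 1:
        nchunks = math.ceil(nchunks / split_every)
        rounds.append(nchunks)

    # two rounds: 8 -> 2 -> 1 chunks; each output chunk has size
    # num_groups (not 1) along the group axis, per combine_sizes
    assert rounds == [2, 1]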
@@ -1032,14 +1071,19 @@ def reduction_new(

     split_every = _normalize_split_every(split_every, axis)
 
+    if func is None:
+        initial_func = None
+    else:
+        initial_func = partial(
+            func, axis=axis, keepdims=True, **(extra_func_kwargs or {})
+        )
     result = partial_reduce(
         x,
         partial(combine_func, **(extra_func_kwargs or {})),
-        initial_func=partial(
-            func, axis=axis, keepdims=True, **(extra_func_kwargs or {})
-        ),
+        initial_func=initial_func,
         split_every=split_every,
         dtype=intermediate_dtype,
+        combine_sizes=combine_sizes,
     )
 
     # combine intermediates
@@ -1049,6 +1093,7 @@
         axis=axis,
         dtype=intermediate_dtype,
         split_every=split_every,
+        combine_sizes=combine_sizes,
     )

     # aggregate final chunks
@@ -1085,6 +1130,7 @@ def tree_reduce(
     axis,
     dtype,
     split_every=None,
+    combine_sizes=None,
 ):
     """Apply a reduction function repeatedly across multiple axes."""
     if axis is None:
@@ -1105,11 +1151,19 @@
             func,
             split_every=split_every,
             dtype=dtype,
+            combine_sizes=combine_sizes,
         )
     return x
 
 
-def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None):
+def partial_reduce(
+    x,
+    func,
+    initial_func=None,
+    split_every=None,
+    dtype=None,
+    combine_sizes=None,
+):
     """Apply a reduction function to multiple blocks across multiple axes.
 
     Parameters
@@ -1118,21 +1172,30 @@ def partial_reduce(x, func, initial_func=None, split_every=None, dtype=None):
         Array being reduced along one or more axes
     func: callable
         Reduction function to apply to each chunk of data, resulting in a chunk
-        with size one in each of the reduction axes.
+        with size one (or the size specified in ``combine_sizes``) in each of
+        the reduction axes.
     initial_func: callable, optional
         Function to apply to each chunk of data before reduction.
     split_every: int >= 2 or dict(axis: int), optional
-        The depth of the recursive aggregation.
-    dtype: DType
+        The number of chunks to combine in one round along each axis in the
+        recursive aggregation.
+    dtype: dtype
         Output data type.
+    combine_sizes: dict(axis: int), optional
+        The resulting size of each axis after reduction. Each reduction axis
+        defaults to size one if not specified.
     """
     # map over output chunks
+    axis = tuple(ax for ax in split_every.keys())
+    combine_sizes = combine_sizes or {}
+    combine_sizes = {k: combine_sizes.get(k, 1) for k in axis}
     chunks = [
-        (1,) * math.ceil(len(c) / split_every[i]) if i in split_every else c
+        (combine_sizes[i],) * math.ceil(len(c) / split_every[i])
+        if i in split_every
+        else c
         for (i, c) in enumerate(x.chunks)
     ]
     shape = tuple(map(sum, chunks))
-    axis = tuple(ax for ax in split_every.keys())
 
     def key_function(out_key):
         out_coords = out_key[1:]
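The chunk bookkeeping in ``partial_reduce`` can be traced by hand (hypothetical values): with eight chunks along axis 0, ``split_every={0: 4}`` collapses them to ``ceil(8 / 4) = 2`` output chunks, and ``combine_sizes={0: 3}`` (say, three groups) makes each output chunk have size 3 along that axis instead of the default 1:

    import math

    split_every = {0: 4}
    combine_sizes = {0: 3}
    x_chunks = ((1,) * 8, (2, 2, 1))  # chunk sizes of a hypothetical (8, 5) array

    axis = tuple(ax for ax in split_every.keys())
    sizes = {k: combine_sizes.get(k, 1) for k in axis}
    chunks = [
        (sizes[i],) * math.ceil(len(c) / split_every[i]) if i in split_every else c
        for (i, c) in enumerate(x_chunks)
    ]
    assert chunks == [(3, 3), (2, 2, 1)]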
61 changes: 61 additions & 0 deletions cubed/tests/test_groupby.py
@@ -0,0 +1,61 @@
import numpy as np
import numpy_groupies as npg
from numpy.testing import assert_array_equal

import cubed.array_api as xp
from cubed.backend_array_api import namespace as nxp
from cubed.core.groupby import groupby_reduction


def test_groupby_reduction_axis0():
    a = xp.full((4 * 6, 5), 7, dtype=nxp.int32, chunks=(4, 2))
    b = xp.asarray([0, 1, 0, 1] * 6, chunks=(4,))
    c = mean_groupby_reduction(a, b, axis=0, num_groups=2)
    assert_array_equal(c.compute(), np.full((2, 5), 7))


def test_groupby_reduction_axis1():
    a = xp.full((5, 4 * 6), 7, dtype=nxp.int32, chunks=(2, 4))
    b = xp.asarray([0, 1, 0, 1] * 6, chunks=(4,))
    c = mean_groupby_reduction(a, b, axis=1, num_groups=2)
    assert_array_equal(c.compute(), np.full((5, 2), 7))


def mean_groupby_reduction(x, by, axis, num_groups):
    intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
    dtype = x.dtype

    return groupby_reduction(
        x,
        by,
        func=_mean_groupby_func,
        combine_func=_mean_groupby_combine,
        aggregate_func=_mean_groupby_aggregate,
        axis=axis,
        intermediate_dtype=intermediate_dtype,
        dtype=dtype,
        num_groups=num_groups,
    )


def _mean_groupby_func(a, by, axis, intermediate_dtype, num_groups):
    dtype = dict(intermediate_dtype)
    n = npg.aggregate(by, a, func="len", dtype=dtype["n"], axis=axis, size=num_groups)
    total = npg.aggregate(
        by, a, func="sum", dtype=dtype["total"], axis=axis, size=num_groups
    )
    return {"n": n, "total": total}


def _mean_groupby_combine(a, axis, dummy_axis, dtype, keepdims):
    # only combine over the dummy axis, to preserve grouping along 'axis'
    dtype = dict(dtype)
    n = nxp.sum(a["n"], dtype=dtype["n"], axis=dummy_axis, keepdims=keepdims)
    total = nxp.sum(
        a["total"], dtype=dtype["total"], axis=dummy_axis, keepdims=keepdims
    )
    return {"n": n, "total": total}


def _mean_groupby_aggregate(a):
    return nxp.divide(a["total"], a["n"])
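For a concrete feel of what the per-block functions above compute, here is a self-contained trace with plain NumPy and numpy_groupies (hypothetical data, mirroring ``_mean_groupby_func`` with ``axis=0`` and ``num_groups=2``):

    import numpy as np
    import numpy_groupies as npg

    a = np.full((4, 3), 7.0)
    by = np.array([0, 1, 0, 1])
    n = npg.aggregate(by, a, func="len", axis=0, size=2)      # counts per group
    total = npg.aggregate(by, a, func="sum", axis=0, size=2)  # sums per group
    assert n.shape == (2, 3) and total.shape == (2, 3)

    # _mean_groupby_combine then sums these over the dummy axis across
    # blocks, and the final aggregate step divides: every group mean is 7.0
    assert np.array_equal(total / n, np.full((2, 3), 7.0))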
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -69,26 +69,30 @@ test = [
"apache-beam", # beam but not gcsfs as tests use local beam runner
"cubed[diagnostics,lithops]", # modal tests separate due to conflicting package reqs
"dill",
"numpy_groupies",
"pytest",
"pytest-cov",
"pytest-mock",
]
test-dask = [
"cubed[dask,diagnostics]",
"dill",
"numpy_groupies",
"pytest",
"pytest-cov",
"pytest-mock",
]
test-dask-distributed = [
"cubed[dask-distributed,diagnostics]",
"dill",
"numpy_groupies",
"pytest",
"pytest-cov",
"pytest-mock",
]
test-modal = [
"cubed[modal]",
"numpy_groupies",
"dill",
"pytest",
"pytest-cov",