Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bincount fix slicing #7391

Merged
merged 10 commits on Mar 18, 2021
29 changes: 24 additions & 5 deletions dask/array/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
from ..utils import funcname, derived_from, is_arraylike
from . import chunk
from .creation import arange, diag, empty, indices, tri, zeros
from .utils import safe_wraps, validate_axis, meta_from_array, zeros_like_safe
from .utils import (
safe_wraps,
validate_axis,
meta_from_array,
zeros_like_safe,
array_safe,
)
from .wrap import ones
from .ufunc import multiply, sqrt

Expand Down Expand Up @@ -639,29 +645,42 @@ def bincount(x, weights=None, minlength=0, split_every=None):
if weights.chunks != x.chunks:
raise ValueError("Chunks of input array x and weights must match.")

axis = (0,)
token = tokenize(x, weights, minlength)
args = [x, "i"]
if weights is not None:
meta = np.bincount([1], weights=[1])
meta = array_safe(np.bincount([1], weights=[1]), x._meta)
GenevieveBuckley marked this conversation as resolved.
Show resolved Hide resolved
args.extend([weights, "i"])
else:
meta = np.bincount([])
meta = array_safe(np.bincount([]), x._meta)
GenevieveBuckley marked this conversation as resolved.
Show resolved Hide resolved

if minlength == 0:
output_size = np.nan
else:
output_size = minlength

chunked_counts = blockwise(
partial(np.bincount, minlength=minlength), "i", *args, token=token, meta=meta
)
chunked_counts._chunks = tuple(
(output_size,) * len(c) if i in axis else c
for i, c in enumerate(chunked_counts.chunks)
)

from .reductions import _tree_reduce

output = _tree_reduce(
chunked_counts,
aggregate=partial(_bincount_agg, dtype=meta.dtype),
axis=(0,),
keepdims=False,
axis=axis,
keepdims=True,
dtype=meta.dtype,
split_every=split_every,
concatenate=False,
)
output._chunks = tuple(
(output_size,) if i in axis else c for i, c in enumerate(chunked_counts.chunks)
)
output._meta = meta
return output

Expand Down
1 change: 1 addition & 0 deletions dask/array/tests/test_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,7 @@ def test_cupy_sparse_concatenate(axis):
assert (z.toarray() == z_expected.toarray()).all()


@pytest.mark.skipif(not _numpy_120, reason="NEP-35 is not available")
@pytest.mark.skipif(
not IS_NEP18_ACTIVE or cupy.__version__ < LooseVersion("6.4.0"),
reason="NEP-18 support is not available in NumPy or CuPy older than "
Expand Down
5 changes: 5 additions & 0 deletions dask/array/tests/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,10 +530,15 @@ def test_bincount():
e = da.bincount(d, minlength=6)
assert_eq(e, np.bincount(x, minlength=6))
assert same_keys(da.bincount(d, minlength=6), e)
assert e.shape == (6,) # shape equal to minlength
assert e.chunks == ((6,),)

assert da.bincount(d, minlength=6).name != da.bincount(d, minlength=7).name
assert da.bincount(d, minlength=6).name == da.bincount(d, minlength=6).name

expected_output = np.array([0, 2, 2, 0, 0, 1])
assert_eq(e[0:].compute(), expected_output) # can bincount result be sliced
GenevieveBuckley marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
"weights",
Expand Down