Skip to content

Commit

Permalink
daskify PackedSelection
Browse files (browse the repository at this point in the history)
  • Loading branch information
lgray committed Mar 15, 2023
1 parent f4e7505 commit e4b398f
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 28 deletions.
38 changes: 31 additions & 7 deletions coffea/analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
These helper classes were previously part of ``coffea.processor``
but have been migrated and updated to be compatible with awkward-array 1.0
"""
import awkward
import dask.array
import dask_awkward
import numpy

import coffea.processor
Expand Down Expand Up @@ -325,21 +328,38 @@ def add(self, name, selection, fill_value=False):
----------
name : str
name of the selection
selection : numpy.ndarray or awkward.Array
selection : numpy.ndarray, awkward.Array, dask.array.Array, or dask_awkward.Array
a flat array of type ``bool`` or ``?bool``.
If this is not the first selection added, it must also have
the same shape as previously added selections. If the array
is option-type, null entries will be filled with ``fill_value``.
fill_value : bool, optional
All masked entries will be filled as specified (default: ``False``)
"""
selection = coffea.util._ensure_flat(selection, allow_missing=True)
if isinstance(selection, numpy.ma.MaskedArray):
selection = selection.filled(fill_value)
array_lib = numpy
if isinstance(selection, (dask.array.Array, dask_awkward.Array)):
array_lib = dask.array
selection = (
selection
if isinstance(selection, dask.array.Array)
else dask_awkward.to_dask_array(selection)
)
if isinstance(selection._meta, numpy.ma.MaskedArray):
selection = dask.array.ma.filled(selection, fill_value)
elif isinstance(selection, (numpy.ndarray, awkward.Array)):
selection = coffea.util._ensure_flat(selection, allow_missing=True)
if isinstance(selection, numpy.ma.MaskedArray):
selection = selection.filled(fill_value)
else:
raise TypeError(
"selection is not a numpy.ndarray, awkward.Array, dask.array.Array, or dask_awkward.Array"
)
if selection.dtype != bool:
raise ValueError(f"Expected a boolean array, received {selection.dtype}")
raise ValueError(
f"Expected a boolean dask array, received {selection.dtype}"
)
if len(self._names) == 0:
self._data = numpy.zeros(len(selection), dtype=self._dtype)
self._data = array_lib.zeros(len(selection), dtype=self._dtype)
elif len(self._names) == self.maxitems:
raise RuntimeError(
f"Exhausted all slots in {self}, consider a larger dtype or fewer selections"
Expand All @@ -348,7 +368,7 @@ def add(self, name, selection, fill_value=False):
raise ValueError(
f"New selection '{name}' has a different shape than existing selections ({selection.shape} vs. {self._data.shape})"
)
numpy.bitwise_or(
array_lib.bitwise_or(
self._data,
self._dtype.type(1 << len(self._names)),
where=selection,
Expand Down Expand Up @@ -389,6 +409,8 @@ def require(self, **names):
idx = self._names.index(name)
consider |= 1 << idx
require |= int(val) << idx
if isinstance(self._data, dask.array.Array):
return dask_awkward.from_dask_array((self._data & consider) == require)
return (self._data & consider) == require

def all(self, *names):
Expand Down Expand Up @@ -422,4 +444,6 @@ def any(self, *names):
for name in names:
idx = self._names.index(name)
consider |= 1 << idx
if isinstance(self._data, dask.array.Array):
return dask_awkward.from_dask_array((self._data & consider) != 0)
return (self._data & consider) != 0
69 changes: 48 additions & 21 deletions tests/test_analysis_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dask.array as da
import numpy as np
import pytest
from dummy_distributions import dummy_jagged_eta_pt
Expand Down Expand Up @@ -160,39 +161,65 @@ def test_weights_partial():
assert error_raised


def test_packed_selection():
@pytest.mark.parametrize("array_lib", [np, da])
def test_packed_selection(array_lib):
from coffea.analysis_tools import PackedSelection

sel = PackedSelection()

shape = (10,)
all_true = np.full(shape=shape, fill_value=True, dtype=bool)
all_false = np.full(shape=shape, fill_value=False, dtype=bool)
fizz = np.arange(shape[0]) % 3 == 0
buzz = np.arange(shape[0]) % 5 == 0
ones = np.ones(shape=shape, dtype=np.uint64)
wrong_shape = ones = np.ones(shape=(shape[0] - 5,), dtype=bool)
all_true = array_lib.full(shape=shape, fill_value=True, dtype=bool)
all_false = array_lib.full(shape=shape, fill_value=False, dtype=bool)
fizz = array_lib.arange(shape[0]) % 3 == 0
buzz = array_lib.arange(shape[0]) % 5 == 0
ones = array_lib.ones(shape=shape, dtype=np.uint64)
wrong_shape = ones = array_lib.ones(shape=(shape[0] - 5,), dtype=bool)

sel.add("all_true", all_true)
sel.add("all_false", all_false)
sel.add("fizz", fizz)
sel.add("buzz", buzz)

assert np.all(sel.require(all_true=True, all_false=False) == all_true)
# allow truthy values
assert np.all(sel.require(all_true=1, all_false=0) == all_true)
assert np.all(sel.all("all_true", "all_false") == all_false)
assert np.all(sel.any("all_true", "all_false") == all_true)
assert np.all(
sel.all("fizz", "buzz")
== np.array(
[True, False, False, False, False, False, False, False, False, False]
if array_lib == np:
assert np.all(sel.require(all_true=True, all_false=False) == all_true)
# allow truthy values
assert np.all(sel.require(all_true=1, all_false=0) == all_true)
assert np.all(sel.all("all_true", "all_false") == all_false)
assert np.all(sel.any("all_true", "all_false") == all_true)
assert np.all(
sel.all("fizz", "buzz")
== np.array(
[True, False, False, False, False, False, False, False, False, False]
)
)
assert np.all(
sel.any("fizz", "buzz")
== np.array(
[True, False, False, True, False, True, True, False, False, True]
)
)
else:
assert np.all(
sel.require(all_true=True, all_false=False).compute() == all_true.compute()
)
# allow truthy values
assert np.all(
sel.require(all_true=1, all_false=0).compute() == all_true.compute()
)
assert np.all(sel.all("all_true", "all_false").compute() == all_false.compute())
assert np.all(sel.any("all_true", "all_false").compute() == all_true.compute())
assert np.all(
sel.all("fizz", "buzz").compute()
== np.array(
[True, False, False, False, False, False, False, False, False, False]
)
)
assert np.all(
sel.any("fizz", "buzz").compute()
== np.array(
[True, False, False, True, False, True, True, False, False, True]
)
)
)
assert np.all(
sel.any("fizz", "buzz")
== np.array([True, False, False, True, False, True, True, False, False, True])
)

with pytest.raises(ValueError):
sel.add("wrong_shape", wrong_shape)
Expand Down

0 comments on commit e4b398f

Please sign in to comment.