Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: remove copyreg workaround, use attrs to store reference to original array #949

Merged
merged 9 commits on Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ classifiers = [
"Topic :: Utilities",
]
dependencies = [
"awkward>=2.5.0",
"uproot>=5.1.2",
"awkward>=2.5.1rc1",
"uproot>=5.2.0rc3",
"dask[array]>=2023.4.0",
"dask-awkward>=2023.11.5",
"dask-awkward>=2023.12.0",
"dask-histogram>=2023.10.0",
"correctionlib>=2.3.3",
"pyarrow>=6.0.0",
Expand Down
7 changes: 0 additions & 7 deletions src/coffea/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,4 @@
# control severity for utils.deprecate
deprecations_as_errors = False

import copyreg

import dask_awkward

copyreg.pickle(dask_awkward.Array, lambda x: (lambda y: y, (None,)))


__all__ = ["deprecations_as_errors"]
7 changes: 3 additions & 4 deletions src/coffea/nanoevents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,21 +664,20 @@ def events(self):
report = None
if isinstance(events, tuple):
events, report = events
events.behavior["__original_array__"] = lambda: events
events.attrs["@original_array"] = events
if report is not None:
return events, report
return events

events = self._events()
if events is None:
behavior = dict(self._schema.behavior())
behavior["__events_factory__"] = self
events = awkward.from_buffers(
self._schema.form,
len(self),
self._mapping,
buffer_key=partial(_key_formatter, self._partition_key),
behavior=behavior,
behavior=self._schema.behavior(),
attrs={"@events_factory": self},
)
self._events = weakref.ref(events)

Expand Down
21 changes: 14 additions & 7 deletions src/coffea/nanoevents/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import awkward
import dask_awkward
import numpy
from dask_awkward import dask_method, dask_property

import coffea
from coffea.util import awkward_rewrap, rewrap_recordarray
Expand Down Expand Up @@ -160,7 +161,7 @@ class NanoEvents(Systematic):
This mixin class is used as the top-level type for NanoEvents objects.
"""

@dask_awkward.dask_property(no_dispatch=True)
@dask_property(no_dispatch=True)
def metadata(self):
"""Arbitrary metadata"""
return self.layout.purelist_parameter("metadata")
Expand All @@ -177,10 +178,12 @@ class NanoCollection:
and other advanced mixin types.
"""

@dask_method(no_dispatch=True)
def _collection_name(self):
"""The name of the collection (i.e. the field under events where it is found)"""
return self.layout.purelist_parameter("collection_name")

@dask_method(no_dispatch=True)
def _getlistarray(self):
"""Do some digging to find the initial listarray"""

Expand All @@ -194,14 +197,15 @@ def descend(layout, depth, **kwargs):

return awkward.transform(descend, self.layout, highlevel=False)

@dask_method(no_dispatch=True)
def _content(self):
"""Internal method to get jagged collection content

This should only be called on the original unsliced collection array.
Used with global indexes to resolve cross-references"""
return self._getlistarray().content

@dask_awkward.dask_method
@dask_method
def _apply_global_index(self, index):
"""Internal method to take from a collection using a flat index

Expand All @@ -224,7 +228,7 @@ def descend(layout, depth, **kwargs):
index._meta if isinstance(index, dask_awkward.Array) else index
)
layout_out = awkward.transform(descend, index_out.layout, highlevel=False)
out = awkward.Array(layout_out, behavior=self.behavior)
out = awkward.Array(layout_out, behavior=self.behavior, attrs=self.attrs)

return out

Expand All @@ -236,14 +240,17 @@ def _apply_global_index(self, dask_array, index):
label="_apply_global_index",
)

@dask_method(no_dispatch=True)
def _events(self):
"""Internal method to get the originally-constructed NanoEvents

This can be called at any time from any collection, as long as
the NanoEventsFactory instance exists."""
if "__original_array__" in self.behavior:
return self.behavior["__original_array__"]()
return self.behavior["__events_factory__"].events()
the NanoEventsFactory instance exists.

This will not work automatically if you read serialized nanoevents."""
if "@original_array" in self.attrs:
return self.attrs["@original_array"]
return self.attrs["@events_factory"].events()


__all__ = ["NanoCollection", "NanoEvents", "Systematic"]
Loading
Loading