Skip to content

Commit

Permalink
Merge pull request #951 from CoffeaTeam/apply-global-index-fix
Browse files Browse the repository at this point in the history
fix: dispatch of _apply_global_index wasn't re-implemented correctly
  • Loading branch information
lgray authored Dec 2, 2023
2 parents 09e965b + d86212d commit 5079569
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 77 deletions.
18 changes: 10 additions & 8 deletions src/coffea/nanoevents/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def _content(self):
Used with global indexes to resolve cross-references"""
return self._getlistarray().content

def _apply_global_index(self, index, _dask_array_=None):
@dask_awkward.dask_method
def _apply_global_index(self, index):
"""Internal method to take from a collection using a flat index
This is often necessary to be able to still resolve cross-references on
Expand All @@ -225,15 +226,16 @@ def descend(layout, depth, **kwargs):
layout_out = awkward.transform(descend, index_out.layout, highlevel=False)
out = awkward.Array(layout_out, behavior=self.behavior)

if isinstance(index, dask_awkward.Array):
return _dask_array_.map_partitions(
_ClassMethodFn("_apply_global_index"),
index,
label="_apply_global_index",
meta=out,
)
return out

@_apply_global_index.dask
def _apply_global_index(self, dask_array, index):
return dask_array.map_partitions(
_ClassMethodFn("_apply_global_index"),
index,
label="_apply_global_index",
)

def _events(self):
"""Internal method to get the originally-constructed NanoEvents
Expand Down
98 changes: 29 additions & 69 deletions src/coffea/nanoevents/methods/nanoaod.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def parent(self):
@parent.dask
def parent(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.genPartIdxMotherG, _dask_array_=original
)
return original._apply_global_index(dask_array.genPartIdxMotherG)

@dask_property
def distinctParent(self):
Expand All @@ -103,9 +101,7 @@ def distinctParent(self):
@distinctParent.dask
def distinctParent(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.distinctParentIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.distinctParentIdxG)

@dask_property
def children(self):
Expand All @@ -114,9 +110,7 @@ def children(self):
@children.dask
def children(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.childrenIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.childrenIdxG)

@dask_property
def distinctChildren(self):
Expand All @@ -125,9 +119,7 @@ def distinctChildren(self):
@distinctChildren.dask
def distinctChildren(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.distinctChildrenIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.distinctChildrenIdxG)

@dask_property
def distinctChildrenDeep(self):
Expand All @@ -144,9 +136,7 @@ def distinctChildrenDeep(self, dask_array):
"distinctChildrenDeep may not give correct answers for all generators!"
)
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.distinctChildrenDeepIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.distinctChildrenDeepIdxG)


_set_repr_name("GenParticle")
Expand All @@ -165,9 +155,7 @@ def parent(self):
def parent(self, dask_array):
"""Accessor to the parent particle"""
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.genPartIdxMotherG, _dask_array_=original
)
return original._apply_global_index(dask_array.genPartIdxMotherG)


_set_repr_name("GenVisTau")
Expand Down Expand Up @@ -216,9 +204,7 @@ def matched_gen(self):
@matched_gen.dask
def matched_gen(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.genPartIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.genPartIdxG)

@dask_property
def matched_jet(self):
Expand All @@ -227,7 +213,7 @@ def matched_jet(self):
@matched_jet.dask
def matched_jet(self, dask_array):
original = dask_array.behavior["__original_array__"]().Jet
return original._apply_global_index(dask_array.jetIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.jetIdxG)

@dask_property
def matched_photon(self):
Expand All @@ -236,9 +222,7 @@ def matched_photon(self):
@matched_photon.dask
def matched_photon(self, dask_array):
original = dask_array.behavior["__original_array__"]().Photon
return original._apply_global_index(
dask_array.photonIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.photonIdxG)


_set_repr_name("Electron")
Expand All @@ -255,9 +239,7 @@ def matched_fsrPhoton(self):
@matched_fsrPhoton.dask
def matched_fsrPhoton(self, dask_array):
original = dask_array.behavior["__original_array__"]().FsrPhoton
return original._apply_global_index(
dask_array.fsrPhotonIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.fsrPhotonIdxG)

@dask_property
def matched_gen(self):
Expand All @@ -266,9 +248,7 @@ def matched_gen(self):
@matched_gen.dask
def matched_gen(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.genPartIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.genPartIdxG)

@dask_property
def matched_jet(self):
Expand All @@ -277,7 +257,7 @@ def matched_jet(self):
@matched_jet.dask
def matched_jet(self, dask_array):
original = dask_array.behavior["__original_array__"]().Jet
return original._apply_global_index(dask_array.jetIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.jetIdxG)


_set_repr_name("Muon")
Expand All @@ -294,9 +274,7 @@ def matched_gen(self):
@matched_gen.dask
def matched_gen(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.genPartIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.genPartIdxG)

@dask_property
def matched_jet(self):
Expand All @@ -305,7 +283,7 @@ def matched_jet(self):
@matched_jet.dask
def matched_jet(self, dask_array):
original = dask_array.behavior["__original_array__"]().Jet
return original._apply_global_index(dask_array.jetIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.jetIdxG)


_set_repr_name("Tau")
Expand Down Expand Up @@ -348,9 +326,7 @@ def matched_electron(self):
@matched_electron.dask
def matched_electron(self, dask_array):
original = dask_array.behavior["__original_array__"]().Electron
return original._apply_global_index(
dask_array.electronIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.electronIdxG)

@dask_property
def matched_gen(self):
Expand All @@ -359,9 +335,7 @@ def matched_gen(self):
@matched_gen.dask
def matched_gen(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenPart
return original._apply_global_index(
dask_array.genPartIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.genPartIdxG)

@dask_property
def matched_jet(self):
Expand All @@ -370,7 +344,7 @@ def matched_jet(self):
@matched_jet.dask
def matched_jet(self, dask_array):
original = dask_array.behavior["__original_array__"]().Jet
return original._apply_global_index(dask_array.jetIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.jetIdxG)


_set_repr_name("Photon")
Expand All @@ -387,7 +361,7 @@ def matched_muon(self):
@matched_muon.dask
def matched_muon(self, dask_array):
original = dask_array.behavior["__original_array__"]().Jet
return original._apply_global_index(dask_array.muonIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.muonIdxG)


_set_repr_name("FsrPhoton")
Expand Down Expand Up @@ -426,9 +400,7 @@ def matched_electrons(self):
@matched_electrons.dask
def matched_electrons(self, dask_array):
original = dask_array.behavior["__original_array__"]().Electron
return original._apply_global_index(
dask_array.electronIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.electronIdxG)

@dask_property
def matched_muons(self):
Expand All @@ -437,7 +409,7 @@ def matched_muons(self):
@matched_muons.dask
def matched_muons(self, dask_array):
original = dask_array.behavior["__original_array__"]().Muon
return original._apply_global_index(dask_array.muonIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.muonIdxG)

@dask_property
def matched_gen(self):
Expand All @@ -446,9 +418,7 @@ def matched_gen(self):
@matched_gen.dask
def matched_gen(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenJet
return original._apply_global_index(
dask_array.genJetIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.genJetIdxG)

@dask_property
def constituents(self):
Expand All @@ -461,9 +431,7 @@ def constituents(self, dask_array):
if "pFCandsIdxG" not in self.fields:
raise RuntimeError("PF candidates are only available for PFNano")
original = dask_array.behavior["__original_array__"]().JetPFCands
return original._apply_global_index(
dask_array.pFCandsIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.pFCandsIdxG)


_set_repr_name("Jet")
Expand Down Expand Up @@ -502,9 +470,7 @@ def subjets(self):
@subjets.dask
def subjets(self, dask_array):
original = dask_array.behavior["__original_array__"]().SubJet
return original._apply_global_index(
dask_array.subJetIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.subJetIdxG)

@dask_property
def matched_gen(self):
Expand All @@ -513,9 +479,7 @@ def matched_gen(self):
@matched_gen.dask
def matched_gen(self, dask_array):
original = dask_array.behavior["__original_array__"]().GenJetAK8
return original._apply_global_index(
dask_array.genJetAK8IdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.genJetAK8IdxG)

@dask_property
def constituents(self):
Expand All @@ -528,9 +492,7 @@ def constituents(self, dask_array):
if "pFCandsIdxG" not in self.fields:
raise RuntimeError("PF candidates are only available for PFNano")
original = dask_array.behavior["__original_array__"]().FatJetPFCands
return original._apply_global_index(
dask_array.pFCandsIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.pFCandsIdxG)


_set_repr_name("FatJet")
Expand Down Expand Up @@ -611,7 +573,7 @@ def jet(self):
def jet(self, dask_array):
collection = self.collection_map[self._collection_name()][0]
original = dask_array.behavior["__original_array__"]()[collection]
return original._apply_global_index(dask_array.jetIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.jetIdxG)

@dask_property
def pf(self):
Expand All @@ -622,9 +584,7 @@ def pf(self):
def pf(self, dask_array):
collection = self.collection_map[self._collection_name()][1]
original = dask_array.behavior["__original_array__"]()[collection]
return original._apply_global_index(
dask_array.pFCandsIdxG, _dask_array_=original
)
return original._apply_global_index(dask_array.pFCandsIdxG)


_set_repr_name("AssociatedPFCand")
Expand All @@ -651,7 +611,7 @@ def jet(self):
def jet(self, dask_array):
collection = self._events()[self.collection_map[self._collection_name()][0]]
original = dask_array.behavior["__original_array__"]()[collection]
return original._apply_global_index(dask_array.jetIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.jetIdxG)

@dask_property
def sv(self):
Expand All @@ -662,7 +622,7 @@ def sv(self):
def sv(self, dask_array):
collection = self.collection_map[self._collection_name()][1]
original = dask_array.behavior["__original_array__"]()[collection]
return original._apply_global_index(dask_array.sVIdxG, _dask_array_=original)
return original._apply_global_index(dask_array.sVIdxG)


_set_repr_name("AssociatedSV")
Expand Down

0 comments on commit 5079569

Please sign in to comment.