Skip to content

Commit

Permalink
delphes is probably daskified
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Mar 16, 2023
1 parent d3a65a5 commit 36bf4c6
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
5 changes: 5 additions & 0 deletions coffea/nanoevents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from coffea.nanoevents.schemas import (
BaseSchema,
DelphesSchema,
NanoAODSchema,
PHYSLITESchema,
TreeMakerSchema,
Expand Down Expand Up @@ -241,6 +242,10 @@ def from_root(
from coffea.nanoevents.methods import physlite

behavior = physlite.behavior
elif schemaclass is DelphesSchema:
from coffea.nanoevents.methods import delphes

behavior = delphes.behavior

map_schema = _map_schema_uproot(
schemaclass=schemaclass,
Expand Down
2 changes: 2 additions & 0 deletions coffea/nanoevents/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class BaseSchema:
def __init__(self, base_form, *args, **kwargs):
params = dict(base_form.get("parameters", {}))
params["__record__"] = "NanoEvents"
if "metadata" in params and params["metadata"] is None:
params.pop("metadata")
params.setdefault("metadata", {})
self._form = {
"class": "RecordArray",
Expand Down
4 changes: 2 additions & 2 deletions coffea/nanoevents/schemas/delphes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DelphesSchema(BaseSchema):
- Any branches named ``{name}_size`` are assumed to be counts branches and converted to offsets ``o{name}``
"""

__dask_capable__ = False
__dask_capable__ = True

warn_missing_crossrefs = True

Expand Down Expand Up @@ -187,7 +187,7 @@ class DelphesSchema(BaseSchema):
"Zd": "Z coordinate of point of closest approach to vertex",
}

def __init__(self, base_form, version="latest"):
def __init__(self, base_form, version="latest", *args, **kwargs):
super().__init__(base_form)
self._version = version
if version == "latest":
Expand Down
39 changes: 25 additions & 14 deletions tests/test_nanoevents_delphes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import awkward as ak
import dask_awkward as dak
import pytest

from coffea.nanoevents import DelphesSchema, NanoEventsFactory
Expand All @@ -9,7 +10,7 @@
def _events():
path = os.path.abspath("tests/samples/delphes.root")
factory = NanoEventsFactory.from_root(
path, treepath="Delphes", schemaclass=DelphesSchema
path, treepath="Delphes", schemaclass=DelphesSchema, permit_dask=True
)
return factory.events()

Expand All @@ -20,7 +21,7 @@ def events():


def test_listify(events):
assert ak.to_list(events[0])
assert ak.to_list(events.CaloJet02[0].compute())


@pytest.mark.parametrize(
Expand Down Expand Up @@ -89,21 +90,25 @@ def test_collection_exists(events, collection):
],
)
def test_lorentz_vectorization(collection, events):
mask = ak.num(events[collection]) > 0
mask = dak.num(events[collection], axis=1) > 0
assert (
ak.parameters(events[collection][mask][0, 0].Area)["__record__"]
ak.parameters(events[collection][mask][0].Area._meta)["__record__"]
== "LorentzVector"
)
assert (
ak.parameters(events[collection][mask][0, 0].SoftDroppedJet)["__record__"]
ak.parameters(events[collection][mask][0].SoftDroppedJet._meta)["__record__"]
== "LorentzVector"
)
assert (
ak.parameters(events[collection][mask][0, 0].SoftDroppedSubJet1)["__record__"]
ak.parameters(events[collection][mask][0].SoftDroppedSubJet1._meta)[
"__record__"
]
== "LorentzVector"
)
assert (
ak.parameters(events[collection][mask][0, 0].SoftDroppedSubJet2)["__record__"]
ak.parameters(events[collection][mask][0].SoftDroppedSubJet2._meta)[
"__record__"
]
== "LorentzVector"
)

Expand Down Expand Up @@ -132,21 +137,27 @@ def test_lorentz_vectorization(collection, events):
],
)
def test_nested_lorentz_vectorization(collection, events):
mask = ak.num(events[collection]) > 0
assert ak.all(ak.num(events[collection].PrunedP4_5, axis=2) == 5)
mask = dak.num(events[collection], axis=1) > 0
assert ak.all(ak.num(events[collection].PrunedP4_5.compute(), axis=2) == 5)
assert (
ak.parameters(events[collection][mask].PrunedP4_5[0, 0, 0])["__record__"]
ak.parameters(events[collection][mask].PrunedP4_5[0, 0]._meta.layout.content)[
"__record__"
]
== "LorentzVector"
)

assert ak.all(ak.num(events[collection].SoftDroppedP4_5, axis=2) == 5)
assert ak.all(ak.num(events[collection].SoftDroppedP4_5.compute(), axis=2) == 5)
assert (
ak.parameters(events[collection][mask].SoftDroppedP4_5[0, 0, 0])["__record__"]
ak.parameters(
events[collection][mask].SoftDroppedP4_5[0, 0]._meta.layout.content
)["__record__"]
== "LorentzVector"
)

assert ak.all(ak.num(events[collection].TrimmedP4_5, axis=2) == 5)
assert ak.all(ak.num(events[collection].TrimmedP4_5.compute(), axis=2) == 5)
assert (
ak.parameters(events[collection][mask].TrimmedP4_5[0, 0, 0])["__record__"]
ak.parameters(events[collection][mask].TrimmedP4_5[0, 0]._meta.layout.content)[
"__record__"
]
== "LorentzVector"
)

0 comments on commit 36bf4c6

Please sign in to comment.