Skip to content

Commit

Permalink
jet resolution is daskified
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Mar 15, 2023
1 parent c768962 commit 74f86b9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
3 changes: 2 additions & 1 deletion coffea/jetmet_tools/FactorizedJetCorrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def getSubCorrections(self, **kwargs):
one = dask_awkward.from_awkward(
awkward.Array(numpy.array(1.0, dtype=numpy.float32)), 1
)
newkwargs = kwargs
elif thetype is awkward.highlevel.Array:
for k, v in corrVars.items():
corrVars[k] = dask_awkward.from_awkward(v, 1)
Expand All @@ -218,7 +219,7 @@ def getSubCorrections(self, **kwargs):
for arg in sig
)

if isinstance(fargs[0], (awkward.highlevel.Array, dask_awkward.Array)):
if isinstance(fargs[0], dask_awkward.Array):
corrections.append(
dask_awkward.map_partitions(func, *fargs, meta=fargs[0]._meta)
)
Expand Down
19 changes: 14 additions & 5 deletions coffea/jetmet_tools/JetResolution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re

import awkward
import dask_awkward
import numpy

from coffea.lookup_tools.jme_standard_function import jme_standard_function
Expand Down Expand Up @@ -147,15 +148,23 @@ def getResolution(self, **kwargs):
jrs = reso.getResolution(JetProperty1=jet.property1,...)
"""
# cache = kwargs.pop("lazy_cache", None)
# form = kwargs.pop("form", None)
thetype = type(kwargs[self.signature[0]])
newkwargs = {}
if thetype is awkward.highlevel.Array:
for k, v in kwargs.items():
newkwargs[k] = dask_awkward.from_awkward(v, 1)
else:
newkwargs = kwargs

resos = []
for i, func in enumerate(self._funcs):
sig = func.signature
args = tuple(kwargs[input] for input in sig)
args = tuple(newkwargs[inp] for inp in sig)

if isinstance(args[0], awkward.highlevel.Array):
resos.append(func(*args)) # update with dask laziness
if isinstance(args[0], dask_awkward.Array):
resos.append(
dask_awkward.map_partitions(func, *args, meta=args[0]._meta)
)
elif isinstance(args[0], numpy.ndarray):
resos.append(func(*args)) # np is non-lazy
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_jetmet_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_jet_resolution():
resos = reso.getResolution(JetEta=test_eta, Rho=test_Rho, JetPt=test_pt)
resos_jag = reso.getResolution(
JetEta=test_eta_jag, Rho=test_Rho_jag, JetPt=test_pt_jag
)
).compute()
assert ak.all(np.abs(resos - ak.flatten(resos_jag)) < 1e-6)

test_pt_jag = test_pt_jag[0:3]
Expand Down Expand Up @@ -245,7 +245,7 @@ def test_jet_resolution():
)
resos_jag = reso.getResolution(
JetEta=test_eta_jag, Rho=test_Rho_jag, JetPt=test_pt_jag
)
).compute()
print("Reference Resolution (jagged):", resos_jag_ref)
print("Resolution (jagged):", resos_jag)
# NB: 5e-4 tolerance was agreed upon by lgray and aperloff, if the differences get bigger over time
Expand Down

0 comments on commit 74f86b9

Please sign in to comment.