Skip to content

Commit

Permalink
all individual correction components daskified
Browse files (browse the repository at this point in the history)
  • Loading branch information
lgray committed Mar 15, 2023
1 parent 74f86b9 commit 8265bf7
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 26 deletions.
7 changes: 6 additions & 1 deletion coffea/jetmet_tools/FactorizedJetCorrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,12 @@ def getSubCorrections(self, **kwargs):

if isinstance(fargs[0], dask_awkward.Array):
corrections.append(
dask_awkward.map_partitions(func, *fargs, meta=fargs[0]._meta)
dask_awkward.map_partitions(
func,
*fargs,
label=f"{self._campaign}-{self._dataera}-{self._datatype}-{self._levels[i]}-{self._jettype}",
meta=fargs[0]._meta,
)
)
elif isinstance(fargs[0], numpy.ndarray):
corrections.append(func(*fargs)) # np is non-lazy
Expand Down
30 changes: 25 additions & 5 deletions coffea/jetmet_tools/JetCorrectionUncertainty.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re

import awkward
import dask_awkward
import numpy

from coffea.lookup_tools.jec_uncertainty_lookup import jec_uncertainty_lookup
Expand Down Expand Up @@ -156,14 +157,33 @@ def getUncertainty(self, **kwargs):
#in a zip iterator
"""
# cache = kwargs.pop("lazy_cache", None)
thetype = type(kwargs[self.signature[0]])
newkwargs = {}
if thetype is awkward.highlevel.Array:
for k, v in kwargs.items():
newkwargs[k] = dask_awkward.from_awkward(v, 1)
else:
newkwargs = kwargs

uncs = []
out_form = awkward.forms.ListOffsetForm(
"i64", awkward.forms.NumpyForm("float32", inner_shape=(2,))
)
out_meta = dask_awkward.typetracer_from_form(out_form)

for i, func in enumerate(self._funcs):
sig = func.signature
args = tuple(kwargs[input] for input in sig)

if isinstance(args[0], awkward.highlevel.Array):
uncs.append(func(*args)) # upgrade this with dask laziness
args = tuple(newkwargs[inp] for inp in sig)

if isinstance(args[0], dask_awkward.Array):
uncs.append(
dask_awkward.map_partitions(
func,
*args,
label=f"{self._campaign}-{self._dataera}-{self._datatype}-{self._levels[i]}-{self._jettype}-uncertainty",
meta=out_meta,
)
)
elif isinstance(args[0], numpy.ndarray):
uncs.append(func(*args)) # np is non-lazy
else:
Expand Down
7 changes: 6 additions & 1 deletion coffea/jetmet_tools/JetResolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def getResolution(self, **kwargs):

if isinstance(args[0], dask_awkward.Array):
resos.append(
dask_awkward.map_partitions(func, *args, meta=args[0]._meta)
dask_awkward.map_partitions(
func,
*args,
label=f"{self._campaign}-{self._dataera}-{self._datatype}-{self._levels[i]}-{self._jettype}-resolution",
meta=args[0]._meta,
)
)
elif isinstance(args[0], numpy.ndarray):
resos.append(func(*args)) # np is non-lazy
Expand Down
32 changes: 26 additions & 6 deletions coffea/jetmet_tools/JetResolutionScaleFactor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re

import awkward
import dask_awkward
import numpy

from coffea.lookup_tools.jersf_lookup import jersf_lookup
Expand Down Expand Up @@ -143,15 +144,34 @@ def getScaleFactor(self, **kwargs):
jersfs = jersf.getScaleFactor(JetProperty1=jet.property1,...)
"""
# cache = kwargs.pop("lazy_cache", None)
# form = kwargs.pop("form", None)

thetype = type(kwargs[self.signature[0]])
newkwargs = {}
if thetype is awkward.highlevel.Array:
for k, v in kwargs.items():
newkwargs[k] = dask_awkward.from_awkward(v, 1)
else:
newkwargs = kwargs

out_form = awkward.forms.ListOffsetForm(
"i64", awkward.forms.NumpyForm("float32", inner_shape=(3,))
)
out_meta = dask_awkward.typetracer_from_form(out_form)

sfs = []
for i, func in enumerate(self._funcs):
sig = func.signature
args = tuple(kwargs[input] for input in sig)

if isinstance(args[0], awkward.highlevel.Array):
sfs.append(func(*args)) # update this with dask
args = tuple(newkwargs[inp] for inp in sig)

if isinstance(args[0], dask_awkward.Array):
sfs.append(
dask_awkward.map_partitions(
func,
*args,
label=f"{self._campaign}-{self._dataera}-{self._datatype}-{self._levels[i]}-{self._jettype}-resolution-scale-factor",
meta=out_meta,
)
)
elif isinstance(args[0], numpy.ndarray):
sfs.append(func(*args)) # np is non-lazy
else:
Expand Down
41 changes: 28 additions & 13 deletions tests/test_jetmet_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_factorized_jet_corrector():
corrs_L1_jag = corrector.getCorrection(
JetEta=test_eta_jag, Rho=test_Rho_jag, JetPt=test_pt_jag, JetA=test_A_jag
).compute()

print("Reference L1 corrections:", corrs_L1_jag_ref)
print("Calculated L1 corrections:", corrs_L1_jag)
assert ak.all(
Expand Down Expand Up @@ -180,6 +181,7 @@ def test_factorized_jet_corrector():
corrs_L1L2L3_jag = corrector.getCorrection(
JetEta=test_eta_jag, Rho=test_Rho_jag, JetPt=test_pt_jag, JetA=test_A_jag
).compute()

print("Reference L1L2L3 corrections:", corrs_L1L2L3_jag_ref)
print("Calculated L1L2L3 corrections:", corrs_L1L2L3_jag)
assert ak.all(
Expand Down Expand Up @@ -272,7 +274,7 @@ def test_jet_correction_uncertainty():

for i, (level, corrs) in enumerate(juncs):
assert corrs.shape[0] == test_eta.shape[0]
assert ak.all(corrs == ak.flatten(juncs_jag[i][1]))
assert ak.all(corrs == ak.flatten(juncs_jag[i][1].compute()))

test_pt_jag = test_pt_jag[0:3]
test_eta_jag = test_eta_jag[0:3]
Expand All @@ -299,11 +301,14 @@ def test_jet_correction_uncertainty():
juncs_jag = list(junc.getUncertainty(JetEta=test_eta_jag, JetPt=test_pt_jag))

for i, (level, corrs) in enumerate(juncs_jag):
materialized_corrs = corrs.compute()
print("Index:", i)
print("Correction level:", level)
print("Reference Uncertainties (jagged):", juncs_jag_ref)
print("Uncertainties (jagged):", corrs)
assert ak.all(np.abs(ak.flatten(juncs_jag_ref) - ak.flatten(corrs)) < 1e-6)
print("Uncertainties (jagged):", materialized_corrs)
assert ak.all(
np.abs(ak.flatten(juncs_jag_ref) - ak.flatten(materialized_corrs)) < 1e-6
)


def test_jet_correction_uncertainty_sources():
Expand Down Expand Up @@ -335,7 +340,7 @@ def test_jet_correction_uncertainty_sources():
for i, (level, corrs) in enumerate(juncs):
assert level in levels
assert corrs.shape[0] == test_eta.shape[0]
assert ak.all(corrs == ak.flatten(juncs_jag[i][1]))
assert ak.all(corrs == ak.flatten(juncs_jag[i][1].compute()))

test_pt_jag = test_pt_jag[0:3]
test_eta_jag = test_eta_jag[0:3]
Expand Down Expand Up @@ -363,11 +368,14 @@ def test_jet_correction_uncertainty_sources():
for i, (level, corrs) in enumerate(juncs_jag):
if level != "Total":
continue
materialized_corrs = corrs.compute()
print("Index:", i)
print("Correction level:", level)
print("Reference Uncertainties (jagged):", juncs_jag_ref)
print("Uncertainties (jagged):", corrs, "\n")
assert ak.all(np.abs(ak.flatten(juncs_jag_ref) - ak.flatten(corrs)) < 1e-6)
print("Uncertainties (jagged):", materialized_corrs, "\n")
assert ak.all(
np.abs(ak.flatten(juncs_jag_ref) - ak.flatten(materialized_corrs)) < 1e-6
)


def test_jet_correction_regrouped_uncertainty_sources():
Expand Down Expand Up @@ -396,7 +404,7 @@ def test_jet_correction_regrouped_uncertainty_sources():
for i, tpl in enumerate(list(junc.getUncertainty(JetEta=test_eta, JetPt=test_pt))):
assert tpl[0] in levels
assert tpl[1].shape[0] == test_eta.shape[0]
assert ak.all(tpl[1] == ak.flatten(juncs_jag[i][1]))
assert ak.all(tpl[1] == ak.flatten(juncs_jag[i][1].compute()))

test_pt_jag = test_pt_jag[0:3]
test_eta_jag = test_eta_jag[0:3]
Expand Down Expand Up @@ -424,11 +432,14 @@ def test_jet_correction_regrouped_uncertainty_sources():
for i, (level, corrs) in enumerate(juncs_jag):
if level != "Total":
continue
materialized_corrs = corrs.compute()
print("Index:", i)
print("Correction level:", level)
print("Reference Uncertainties (jagged):", juncs_jag_ref)
print("Uncertainties (jagged):", corrs, "\n")
assert ak.all(np.abs(ak.flatten(juncs_jag_ref) - ak.flatten(corrs)) < 1e-6)
print("Uncertainties (jagged):", materialized_corrs, "\n")
assert ak.all(
np.abs(ak.flatten(juncs_jag_ref) - ak.flatten(materialized_corrs)) < 1e-6
)


def test_jet_resolution_sf():
Expand All @@ -448,7 +459,7 @@ def test_jet_resolution_sf():
assert resosf.getScaleFactor(JetEta=test_eta[:0]).shape == (0, 3)

resosfs = resosf.getScaleFactor(JetEta=test_eta)
resosfs_jag = resosf.getScaleFactor(JetEta=test_eta_jag)
resosfs_jag = resosf.getScaleFactor(JetEta=test_eta_jag).compute()
assert ak.all(resosfs == ak.flatten(resosfs_jag))

test_pt_jag = test_pt_jag[0:3]
Expand All @@ -473,7 +484,7 @@ def test_jet_resolution_sf():
),
counts,
)
resosfs_jag = resosf.getScaleFactor(JetEta=test_eta_jag)
resosfs_jag = resosf.getScaleFactor(JetEta=test_eta_jag).compute()
print("Reference Resolution SF (jagged):", resosfs_jag_ref)
print("Resolution SF (jagged):", resosfs_jag)
assert ak.all(np.abs(ak.flatten(resosfs_jag_ref) - ak.flatten(resosfs_jag)) < 1e-6)
Expand All @@ -497,7 +508,9 @@ def test_jet_resolution_sf_2d():
assert resosf.getScaleFactor(JetPt=test_pt[:0], JetEta=test_eta[:0]).shape == (0, 3)

resosfs = resosf.getScaleFactor(JetPt=test_pt, JetEta=test_eta)
resosfs_jag = resosf.getScaleFactor(JetPt=test_pt_jag, JetEta=test_eta_jag)
resosfs_jag = resosf.getScaleFactor(
JetPt=test_pt_jag, JetEta=test_eta_jag
).compute()
assert ak.all(resosfs == ak.flatten(resosfs_jag))

test_pt_jag = test_pt_jag[0:3]
Expand All @@ -522,7 +535,9 @@ def test_jet_resolution_sf_2d():
),
counts,
)
resosfs_jag = resosf.getScaleFactor(JetPt=test_pt_jag, JetEta=test_eta_jag)
resosfs_jag = resosf.getScaleFactor(
JetPt=test_pt_jag, JetEta=test_eta_jag
).compute()
print("Reference Resolution SF (jagged):", resosfs_jag_ref)
print("Resolution SF (jagged):", resosfs_jag)
assert ak.all(np.abs(ak.flatten(resosfs_jag_ref) - ak.flatten(resosfs_jag)) < 1e-6)
Expand Down

0 comments on commit 8265bf7

Please sign in to comment.