Skip to content

Commit

Permalink
more efficient distribution strategy for heavy lookups, first try
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Mar 16, 2023
1 parent ff251c7 commit 4ce7c3f
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 30 deletions.
10 changes: 6 additions & 4 deletions coffea/lookup_tools/correctionlib_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import dask

from coffea.lookup_tools.lookup_base import lookup_base


class correctionlib_wrapper(lookup_base):
def __init__(self, payload):
super().__init__()
self._corr = payload
dask_future = dask.delayed(
self, pure=True, name=f"{self._corr.name}-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

# Evaluate the wrapped correction held in self._corr (the payload passed to
# the constructor). Positional arguments are forwarded unchanged to its
# evaluate(); **kwargs is accepted in the signature but is not forwarded.
def _evaluate(self, *args, **kwargs):
return self._corr.evaluate(*args)
Expand All @@ -15,6 +20,3 @@ def __repr__(self):
for i, inp in enumerate(self._corr._base.inputs)
)
return f"correctionlib Correction: {self._corr.name}({signature})"

def __dask_tokenize__(self):
return (correctionlib_wrapper, self._corr.name, self._corr.version)
6 changes: 5 additions & 1 deletion coffea/lookup_tools/dense_evaluated_lookup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy

import dask
import numba
import numpy

Expand Down Expand Up @@ -30,7 +31,6 @@ def numbaize(fstr, varlist):
# methods for dealing with b-tag SFs
class dense_evaluated_lookup(lookup_base):
def __init__(self, values, dims, feval_dim=None):
super().__init__()
self._dimension = 0
whattype = type(dims)
if whattype == numpy.ndarray:
Expand Down Expand Up @@ -66,6 +66,10 @@ def __init__(self, values, dims, feval_dim=None):
"lookup_tools.evaluator only accepts 1D functions right now!"
)
self._feval_dim = feval_dim[0]
dask_future = dask.delayed(
self, pure=True, name=f"denseevallookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
indices = []
Expand Down
6 changes: 5 additions & 1 deletion coffea/lookup_tools/dense_lookup.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from copy import deepcopy

import dask
import numpy

from coffea.lookup_tools.lookup_base import lookup_base


class dense_lookup(lookup_base):
def __init__(self, values, dims, feval_dim=None):
super().__init__()
self._dimension = 0
whattype = type(dims)
if whattype == numpy.ndarray:
Expand All @@ -29,6 +29,10 @@ def __init__(self, values, dims, feval_dim=None):
if vals_are_strings:
raise Exception("dense_lookup cannot handle string values!")
self._values = deepcopy(values)
dask_future = dask.delayed(
self, pure=True, name=f"denselookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
if len(args) != self._dimension:
Expand Down
5 changes: 5 additions & 0 deletions coffea/lookup_tools/dense_mapped_lookup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numbers
from threading import Lock

import dask
import numba
import numpy

Expand All @@ -16,6 +17,10 @@ def __init__(self, axes, mapping, formulas, feval_dim):
self._mapping = mapping
self._formulas = formulas
self._feval_dim = feval_dim
dask_future = dask.delayed(
self, pure=True, name=f"densemappedlookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

@classmethod
def _compile(cls, formula):
Expand Down
6 changes: 5 additions & 1 deletion coffea/lookup_tools/jec_uncertainty_lookup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy

import dask
import numpy
from scipy.interpolate import interp1d

Expand Down Expand Up @@ -36,7 +37,6 @@ def __init__(self, formula, bins_and_orders, knots_and_vars):
The constructor takes the output of the "convert_junc_txt_file"
text file converter, which returns a formula, bins, and an interpolation table.
"""
super().__init__()
self._dim_order = bins_and_orders[1]
self._bins = bins_and_orders[0]
self._eval_vars = knots_and_vars[1]
Expand Down Expand Up @@ -78,6 +78,10 @@ def __init__(self, formula, bins_and_orders, knots_and_vars):
self._eval_args[argname] = i + len(self._dim_order)
if argname in self._dim_args.keys():
self._eval_args[argname] = self._dim_args[argname]
dask_future = dask.delayed(
self, pure=True, name=f"junclookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
"""uncertainties = f(args)"""
Expand Down
6 changes: 5 additions & 1 deletion coffea/lookup_tools/jersf_lookup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy

import dask
import numpy

from coffea.lookup_tools.lookup_base import lookup_base
Expand Down Expand Up @@ -32,7 +33,6 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders)
The constructor takes the output of the "convert_jersf_txt_file"
text file converter, which returns a formula, bins, and values.
"""
super().__init__()
self._dim_order = bins_and_orders[1]
self._bins = bins_and_orders[0]
self._eval_vars = clamps_and_vars[2]
Expand Down Expand Up @@ -65,6 +65,10 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders)
self._eval_args[argname] = i + len(self._dim_order)
if argname in self._dim_args.keys():
self._eval_args[argname] = self._dim_args[argname]
dask_future = dask.delayed(
self, pure=True, name=f"jersflookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
"""SFs = f(args)"""
Expand Down
6 changes: 5 additions & 1 deletion coffea/lookup_tools/jme_standard_function.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy

import awkward
import dask
import numpy
from numpy import sqrt # noqa: F401
from numpy import abs, exp, log, log10 # noqa: F401
Expand Down Expand Up @@ -92,7 +93,6 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders)
The constructor takes the output of the "convert_jec(jr)_txt_file"
text file converter, which returns a formula, bins, and parameter values.
"""
super().__init__()
self._dim_order = bins_and_orders[1]
self._bins = bins_and_orders[0]
self._eval_vars = clamps_and_vars[2]
Expand Down Expand Up @@ -130,6 +130,10 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders)
self._eval_args[argname] = i + len(self._dim_order)
if argname in self._dim_args.keys():
self._eval_args[argname] = self._dim_args[argname]
dask_future = dask.delayed(
self, pure=True, name=f"jmestandardlookup-{dask.base.tokenize(self)}"
).persist()
super().__init__(dask_future)

def _evaluate(self, *args, **kwargs):
"""jec/jer = f(args)"""
Expand Down
66 changes: 46 additions & 20 deletions coffea/lookup_tools/lookup_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import numbers
import weakref
from functools import partial

import awkward
import dask_awkward
import numpy


def getfunction(args, thelookup=None, __pre_args__=tuple(), **kwargs):
def getfunction(
args, thelookup_dask=None, thelookup_wref=None, __pre_args__=tuple(), **kwargs
):
if not isinstance(args, (list, tuple)):
args = (args,)
if all(
Expand All @@ -17,6 +20,11 @@ def getfunction(args, thelookup=None, __pre_args__=tuple(), **kwargs):
result = None
backend = awkward.backend(*args)
if backend == "cpu":
thelookup = None
if thelookup_wref is not None:
thelookup = thelookup_wref()
else:
thelookup = thelookup_dask.compute()
result = thelookup._evaluate(
*(list(__pre_args__) + [awkward.to_numpy(arg) for arg in args]),
**kwargs,
Expand All @@ -26,8 +34,7 @@ def getfunction(args, thelookup=None, __pre_args__=tuple(), **kwargs):
for arg in args:
arg._touch_data(recursive=True)
zlargs.append(arg.form.length_zero_array())
zlargs = tuple(arg.form.length_zero_array() for arg in args)
result = thelookup._evaluate(
result = thelookup_wref()._evaluate(
*(list(__pre_args__) + [awkward.to_numpy(zlarg) for zlarg in zlargs]),
**kwargs,
)
Expand All @@ -43,23 +50,25 @@ def getfunction(args, thelookup=None, __pre_args__=tuple(), **kwargs):


class _LookupXformFn:
def __init__(self, *args, thelookup, **kwargs):
def __init__(self, *args, thelookup_dask, thelookup_wref, **kwargs):
self.func = partial(
getfunction, thelookup=thelookup, __pre_args__=args, **kwargs
getfunction,
thelookup_dask=thelookup_dask,
thelookup_wref=thelookup_wref,
__pre_args__=args,
**kwargs,
)

# Apply the partial built in __init__ (self.func, which wraps getfunction)
# to the given arrays through awkward.transform and return its result.
def __call__(self, *args):
return awkward.transform(self.func, *args)

def __dask_tokenize__(self):
return (_LookupXformFn, self.func)


class lookup_base:
"""Base class for all objects that do some sort of value or function lookup"""

def __init__(self):
pass
def __init__(self, dask_future):
self._dask_future = dask_future
self._weakref = weakref.ref(self)

def __call__(self, *args, **kwargs):
dask_label = kwargs.pop("dask_label", None)
Expand All @@ -71,13 +80,18 @@ def __call__(self, *args, **kwargs):
actual_args = tuple(
arg for arg in args if isinstance(arg, dask_awkward.Array)
)
tomap = _LookupXformFn(*delay_args, thelookup=self, **kwargs)
tomap = _LookupXformFn(
*delay_args,
thelookup_dask=self._dask_future,
thelookup_wref=self._weakref,
**kwargs,
)

zlargs = [arg._meta.layout.form.length_zero_array() for arg in actual_args]
zlout = tomap(*zlargs)
meta = dask_awkward.typetracer_from_form(zlout.layout.form)

if dask_label:
if dask_label is not None:
return dask_awkward.map_partitions(
tomap,
*actual_args,
Expand All @@ -91,21 +105,33 @@ def __call__(self, *args, **kwargs):
meta=meta,
)

if all(isinstance(x, (numpy.ndarray, numbers.Number)) for x in args):
if all(isinstance(x, (numpy.ndarray, numbers.Number, str)) for x in args):
return self._evaluate(*args, **kwargs)
elif any(not isinstance(x, awkward.highlevel.Array) for x in args):
elif any(
not isinstance(x, (awkward.highlevel.Array, numbers.Number, str))
for x in args
):
raise TypeError(
"lookup base must receive high level awkward arrays,"
" numpy arrays, or numbers!"
" numpy arrays, strings, or numbers!"
)

# behavior = awkward._util.behavior_of(*args)
func = partial(getfunction, thelookup=self, **kwargs)
out = awkward.transform(func, *args)
non_array_args = tuple(
arg for arg in args if not isinstance(arg, awkward.highlevel.Array)
)
array_args = tuple(
arg for arg in args if isinstance(arg, awkward.highlevel.Array)
)
func = partial(
getfunction,
thelookup_dask=self._dask_future,
thelookup_wref=self._weakref,
__pre_args__=non_array_args,
**kwargs,
)
out = awkward.transform(func, *array_args)
return out

def __dask_tokenize__(self):
return (lookup_base, self)

# Placeholder evaluation hook: the base class provides no lookup logic and
# always raises NotImplementedError; the concrete lookup classes define
# their own _evaluate.
def _evaluate(self, *args, **kwargs):
raise NotImplementedError
2 changes: 1 addition & 1 deletion tests/test_lookup_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_evaluate_noimpl():
from coffea.lookup_tools.lookup_base import lookup_base

try:
lookup_base()._evaluate()
lookup_base(None)._evaluate()
except NotImplementedError:
pass

Expand Down

0 comments on commit 4ce7c3f

Please sign in to comment.