diff --git a/coffea/lookup_tools/correctionlib_wrapper.py b/coffea/lookup_tools/correctionlib_wrapper.py index 4019b6551..666fc17b7 100644 --- a/coffea/lookup_tools/correctionlib_wrapper.py +++ b/coffea/lookup_tools/correctionlib_wrapper.py @@ -1,10 +1,15 @@ +import dask + from coffea.lookup_tools.lookup_base import lookup_base class correctionlib_wrapper(lookup_base): def __init__(self, payload): - super().__init__() self._corr = payload + dask_future = dask.delayed( + self, pure=True, name=f"{self._corr.name}-{dask.base.tokenize(self)}" + ).persist() + super().__init__(dask_future) def _evaluate(self, *args, **kwargs): return self._corr.evaluate(*args) @@ -15,6 +20,3 @@ def __repr__(self): for i, inp in enumerate(self._corr._base.inputs) ) return f"correctionlib Correction: {self._corr.name}({signature})" - - def __dask_tokenize__(self): - return (correctionlib_wrapper, self._corr.name, self._corr.version) diff --git a/coffea/lookup_tools/dense_evaluated_lookup.py b/coffea/lookup_tools/dense_evaluated_lookup.py index 8b6c2a004..de5ab280e 100644 --- a/coffea/lookup_tools/dense_evaluated_lookup.py +++ b/coffea/lookup_tools/dense_evaluated_lookup.py @@ -1,5 +1,6 @@ from copy import deepcopy +import dask import numba import numpy @@ -30,7 +31,6 @@ def numbaize(fstr, varlist): # methods for dealing with b-tag SFs class dense_evaluated_lookup(lookup_base): def __init__(self, values, dims, feval_dim=None): - super().__init__() self._dimension = 0 whattype = type(dims) if whattype == numpy.ndarray: @@ -66,6 +66,10 @@ def __init__(self, values, dims, feval_dim=None): "lookup_tools.evaluator only accepts 1D functions right now!" ) self._feval_dim = feval_dim[0] + dask_future = dask.delayed( + self, pure=True, name=f"denseevallookup-{dask.base.tokenize(self)}" + ).persist() + super().__init__(dask_future) def _evaluate(self, *args, **kwargs): indices = [] diff --git a/coffea/lookup_tools/dense_lookup.py b/coffea/lookup_tools/dense_lookup.py index ba1a9b097..c67c9947a 100644 --- a/coffea/lookup_tools/dense_lookup.py +++ b/coffea/lookup_tools/dense_lookup.py @@ -1,5 +1,6 @@ from copy import deepcopy +import dask import numpy from coffea.lookup_tools.lookup_base import lookup_base @@ -7,7 +8,6 @@ class dense_lookup(lookup_base): def __init__(self, values, dims, feval_dim=None): - super().__init__() self._dimension = 0 whattype = type(dims) if whattype == numpy.ndarray: @@ -29,6 +29,10 @@ def __init__(self, values, dims, feval_dim=None): if vals_are_strings: raise Exception("dense_lookup cannot handle string values!") self._values = deepcopy(values) + dask_future = dask.delayed( + self, pure=True, name=f"denselookup-{dask.base.tokenize(self)}" + ).persist() + super().__init__(dask_future) def _evaluate(self, *args, **kwargs): if len(args) != self._dimension: diff --git a/coffea/lookup_tools/dense_mapped_lookup.py b/coffea/lookup_tools/dense_mapped_lookup.py index 4efd87129..26ed41439 100644 --- a/coffea/lookup_tools/dense_mapped_lookup.py +++ b/coffea/lookup_tools/dense_mapped_lookup.py @@ -1,6 +1,7 @@ import numbers from threading import Lock +import dask import numba import numpy @@ -16,6 +17,10 @@ def __init__(self, axes, mapping, formulas, feval_dim): self._mapping = mapping self._formulas = formulas self._feval_dim = feval_dim + dask_future = dask.delayed( + self, pure=True, name=f"densemappedlookup-{dask.base.tokenize(self)}" + ).persist() + super().__init__(dask_future) @classmethod def _compile(cls, formula): diff --git a/coffea/lookup_tools/jec_uncertainty_lookup.py b/coffea/lookup_tools/jec_uncertainty_lookup.py index 01fdb39b9..0d3e825fa 100644 --- a/coffea/lookup_tools/jec_uncertainty_lookup.py +++ b/coffea/lookup_tools/jec_uncertainty_lookup.py @@ -1,5 +1,6 @@ from copy import deepcopy +import dask import numpy from scipy.interpolate import interp1d @@ -36,7 +37,6 @@ def __init__(self, formula, bins_and_orders, knots_and_vars): The constructor takes the output of the "convert_junc_txt_file" text file converter, which returns a formula, bins, and an interpolation table. """ - super().__init__() self._dim_order = bins_and_orders[1] self._bins = bins_and_orders[0] self._eval_vars = knots_and_vars[1] @@ -78,6 +78,10 @@ def __init__(self, formula, bins_and_orders, knots_and_vars): self._eval_args[argname] = i + len(self._dim_order) if argname in self._dim_args.keys(): self._eval_args[argname] = self._dim_args[argname] + dask_future = dask.delayed( + self, pure=True, name=f"junclookup-{dask.base.tokenize(self)}" + ).persist() + super().__init__(dask_future) def _evaluate(self, *args, **kwargs): """uncertainties = f(args)""" diff --git a/coffea/lookup_tools/jersf_lookup.py b/coffea/lookup_tools/jersf_lookup.py index 35b59001e..9c85d7ff6 100644 --- a/coffea/lookup_tools/jersf_lookup.py +++ b/coffea/lookup_tools/jersf_lookup.py @@ -1,5 +1,6 @@ from copy import deepcopy +import dask import numpy from coffea.lookup_tools.lookup_base import lookup_base @@ -32,7 +33,6 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders) The constructor takes the output of the "convert_jersf_txt_file" text file converter, which returns a formula, bins, and values. """ - super().__init__() self._dim_order = bins_and_orders[1] self._bins = bins_and_orders[0] self._eval_vars = clamps_and_vars[2] @@ -65,6 +65,10 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders) self._eval_args[argname] = i + len(self._dim_order) if argname in self._dim_args.keys(): self._eval_args[argname] = self._dim_args[argname] + dask_future = dask.delayed( + self, pure=True, name=f"jersflookup-{dask.base.tokenize(self)}" + ).persist() + super().__init__(dask_future) def _evaluate(self, *args, **kwargs): """SFs = f(args)""" diff --git a/coffea/lookup_tools/jme_standard_function.py b/coffea/lookup_tools/jme_standard_function.py index 10d65be9f..6bcc71f86 100644 --- a/coffea/lookup_tools/jme_standard_function.py +++ b/coffea/lookup_tools/jme_standard_function.py @@ -1,6 +1,7 @@ from copy import deepcopy import awkward +import dask import numpy from numpy import sqrt # noqa: F401 from numpy import abs, exp, log, log10 # noqa: F401 @@ -92,7 +93,6 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders) The constructor takes the output of the "convert_jec(jr)_txt_file" text file converter, which returns a formula, bins, and parameter values. """ - super().__init__() self._dim_order = bins_and_orders[1] self._bins = bins_and_orders[0] self._eval_vars = clamps_and_vars[2] @@ -130,6 +130,10 @@ def __init__(self, formula, bins_and_orders, clamps_and_vars, params_and_orders) self._eval_args[argname] = i + len(self._dim_order) if argname in self._dim_args.keys(): self._eval_args[argname] = self._dim_args[argname] + dask_future = dask.delayed( + self, pure=True, name=f"jmestandardlookup-{dask.base.tokenize(self)}" + ).persist() + super().__init__(dask_future) def _evaluate(self, *args, **kwargs): """jec/jer = f(args)""" diff --git a/coffea/lookup_tools/lookup_base.py b/coffea/lookup_tools/lookup_base.py index 3c95eade4..c8b2f7614 100644 --- a/coffea/lookup_tools/lookup_base.py +++ b/coffea/lookup_tools/lookup_base.py @@ -1,4 +1,5 @@ import numbers +import weakref from functools import partial import awkward @@ -6,7 +7,9 @@ import numpy -def getfunction(args, thelookup=None, __pre_args__=tuple(), **kwargs): +def getfunction( + args, thelookup_dask=None, thelookup_wref=None, __pre_args__=tuple(), **kwargs +): if not isinstance(args, (list, tuple)): args = (args,) if all( @@ -17,6 +20,11 @@ def getfunction(args, thelookup=None, __pre_args__=tuple(), **kwargs): result = None backend = awkward.backend(*args) if backend == "cpu": + thelookup = None + if thelookup_wref is not None: + thelookup = thelookup_wref() + else: + thelookup = thelookup_dask.compute() result = thelookup._evaluate( *(list(__pre_args__) + [awkward.to_numpy(arg) for arg in args]), **kwargs, @@ -26,8 +34,7 @@ def getfunction(args, thelookup=None, __pre_args__=tuple(), **kwargs): for arg in args: arg._touch_data(recursive=True) zlargs.append(arg.form.length_zero_array()) - zlargs = tuple(arg.form.length_zero_array() for arg in args) - result = thelookup._evaluate( + result = thelookup_wref()._evaluate( *(list(__pre_args__) + [awkward.to_numpy(zlarg) for zlarg in zlargs]), **kwargs, ) @@ -43,23 +50,25 @@ def getfunction(args, thelookup=None, __pre_args__=tuple(), **kwargs): class _LookupXformFn: - def __init__(self, *args, thelookup, **kwargs): + def __init__(self, *args, thelookup_dask, thelookup_wref, **kwargs): self.func = partial( - getfunction, thelookup=thelookup, __pre_args__=args, **kwargs + getfunction, + thelookup_dask=thelookup_dask, + thelookup_wref=thelookup_wref, + __pre_args__=args, + **kwargs, ) def __call__(self, *args): return awkward.transform(self.func, *args) - def __dask_tokenize__(self): - return (_LookupXformFn, self.func) - class lookup_base: """Base class for all objects that do some sort of value or function lookup""" - def __init__(self): - pass + def __init__(self, dask_future): + self._dask_future = dask_future + self._weakref = weakref.ref(self) def __call__(self, *args, **kwargs): dask_label = kwargs.pop("dask_label", None) @@ -71,13 +80,18 @@ def __call__(self, *args, **kwargs): actual_args = tuple( arg for arg in args if isinstance(arg, dask_awkward.Array) ) - tomap = _LookupXformFn(*delay_args, thelookup=self, **kwargs) + tomap = _LookupXformFn( + *delay_args, + thelookup_dask=self._dask_future, + thelookup_wref=self._weakref, + **kwargs, + ) zlargs = [arg._meta.layout.form.length_zero_array() for arg in actual_args] zlout = tomap(*zlargs) meta = dask_awkward.typetracer_from_form(zlout.layout.form) - if dask_label: + if dask_label is not None: return dask_awkward.map_partitions( tomap, *actual_args, @@ -91,21 +105,33 @@ def __call__(self, *args, **kwargs): meta=meta, ) - if all(isinstance(x, (numpy.ndarray, numbers.Number)) for x in args): + if all(isinstance(x, (numpy.ndarray, numbers.Number, str)) for x in args): return self._evaluate(*args, **kwargs) - elif any(not isinstance(x, awkward.highlevel.Array) for x in args): + elif any( + not isinstance(x, (awkward.highlevel.Array, numbers.Number, str)) + for x in args + ): raise TypeError( "lookup base must receive high level awkward arrays," - " numpy arrays, or numbers!" + " numpy arrays, strings, or numbers!" ) # behavior = awkward._util.behavior_of(*args) - func = partial(getfunction, thelookup=self, **kwargs) - out = awkward.transform(func, *args) + non_array_args = tuple( + arg for arg in args if not isinstance(arg, awkward.highlevel.Array) + ) + array_args = tuple( + arg for arg in args if isinstance(arg, awkward.highlevel.Array) + ) + func = partial( + getfunction, + thelookup_dask=self._dask_future, + thelookup_wref=self._weakref, + __pre_args__=non_array_args, + **kwargs, + ) + out = awkward.transform(func, *array_args) return out - def __dask_tokenize__(self): - return (lookup_base, self) - def _evaluate(self, *args, **kwargs): raise NotImplementedError diff --git a/tests/test_lookup_tools.py b/tests/test_lookup_tools.py index f4d20f387..ce995553b 100644 --- a/tests/test_lookup_tools.py +++ b/tests/test_lookup_tools.py @@ -137,7 +137,7 @@ def test_evaluate_noimpl(): from coffea.lookup_tools.lookup_base import lookup_base try: - lookup_base()._evaluate() + lookup_base(None)._evaluate() except NotImplementedError: pass