From 0cd194014272d4b0b1fcc348564c91a61adb1580 Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Wed, 18 Oct 2023 22:11:35 +0100 Subject: [PATCH] fix: cache form remapping to avoid per-chunk workload --- src/uproot/_dask.py | 89 +++++++++++++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 31 deletions(-) diff --git a/src/uproot/_dask.py b/src/uproot/_dask.py index 3d77d5f78..ae8ee8b69 100644 --- a/src/uproot/_dask.py +++ b/src/uproot/_dask.py @@ -885,8 +885,9 @@ def __call__(self, form: Form) -> tuple[Form, TrivialFormMappingInfo]: class UprootReadMixin: - form_mapping: ImplementsFormMapping base_form: Form + expected_form: Form + form_mapping_info: ImplementsFormMappingInfo common_keys: frozenset[str] interp_options: dict[str, Any] @@ -898,24 +899,24 @@ def read_tree(self, tree: HasBranches, start: int, stop: int) -> AwkArray: awkward = uproot.extras.awkward() nplike = Numpy.instance() - form, form_info = self.form_mapping(self.base_form) - # The remap implementation should correctly populate the generated # buffer mapping in __call__, such that the high-level form can be # used in `from_buffers` - mapping = form_info.load_buffers( + mapping = self.form_mapping_info.load_buffers( tree, self.common_keys, start, stop, self.interp_options ) # Populate container with placeholders if keys aren't required # Otherwise, read from disk container = {} - for buffer_key, dtype in form.expected_from_buffers( - buffer_key=form_info.buffer_key + for buffer_key, dtype in self.expected_form.expected_from_buffers( + buffer_key=self.form_mapping_info.buffer_key ).items(): # Which key(s) does this buffer require. This code permits the caller # to require multiple keys to compute a single buffer. - keys_for_buffer = form_info.keys_for_buffer_keys(frozenset({buffer_key})) + keys_for_buffer = self.form_mapping_info.keys_for_buffer_keys( + frozenset({buffer_key}) + ) # If reading this buffer loads a permitted key, read from the tree # We might not have _all_ keys if e.g. buffer A requires one # but not two of the keys required for buffer B @@ -930,20 +931,19 @@ def read_tree(self, tree: HasBranches, start: int, stop: int) -> AwkArray: ) return awkward.from_buffers( - form, + self.expected_form, stop - start, container, - behavior=form_info.behavior, - buffer_key=form_info.buffer_key, + behavior=self.form_mapping_info.behavior, + buffer_key=self.form_mapping_info.buffer_key, ) def mock(self) -> AwkArray: awkward = uproot.extras.awkward() - high_level_form, form_info = self.form_mapping(self.base_form) return awkward.typetracer.typetracer_from_form( - high_level_form, + self.expected_form, highlevel=True, - behavior=form_info.behavior, + behavior=self.form_mapping_info.behavior, ) def prepare_for_projection(self) -> tuple[AwkArray, TypeTracerReport, dict]: @@ -952,14 +952,13 @@ def prepare_for_projection(self) -> tuple[AwkArray, TypeTracerReport, dict]: # A form mapping will (may) remap the base form into a new form # The remapped form can be queried for structural information - high_level_form, form_info = self.form_mapping(self.base_form) # Build typetracer and associated report object meta, report = awkward.typetracer.typetracer_with_report( - high_level_form, + self.expected_form, highlevel=True, - behavior=form_info.behavior, - buffer_key=form_info.buffer_key, + behavior=self.form_mapping_info.behavior, + buffer_key=self.form_mapping_info.buffer_key, ) return ( @@ -967,10 +966,10 @@ def prepare_for_projection(self) -> tuple[AwkArray, TypeTracerReport, dict]: report, { "trace": dask_awkward.lib.utils.trace_form_structure( - high_level_form, - buffer_key=form_info.buffer_key, + self.expected_form, + buffer_key=self.form_mapping_info.buffer_key, ), - "form_info": form_info, + "form_info": self.form_mapping_info, }, ) @@ -1016,20 +1015,27 @@ class _UprootRead(UprootReadMixin): def __init__( self, ttrees, - common_keys, - interp_options, - form_mapping: ImplementsFormMapping, - base_form, + common_keys: frozenset[str], + interp_options: dict[str, Any], + base_form: Form, + expected_form: Form, + form_mapping_info: ImplementsFormMappingInfo, ) -> None: self.ttrees = ttrees self.common_keys = frozenset(common_keys) self.interp_options = interp_options - self.form_mapping = form_mapping self.base_form = base_form + self.expected_form = expected_form + self.form_mapping_info = form_mapping_info def project_keys(self: T, keys: frozenset[str]) -> T: return _UprootRead( - self.ttrees, keys, self.interp_options, self.form_mapping, self.base_form + self.ttrees, + keys, + self.interp_options, + self.base_form, + self.expected_form, + self.form_mapping_info, ) def __call__(self, i_start_stop) -> AwkArray: @@ -1046,16 +1052,18 @@ def __init__( real_options, common_keys, interp_options, - form_mapping: ImplementsFormMapping, base_form: Form, + expected_form: Form, + form_mapping_info: ImplementsFormMappingInfo, ) -> None: self.custom_classes = custom_classes self.allow_missing = allow_missing self.real_options = real_options self.common_keys = frozenset(common_keys) self.interp_options = interp_options - self.form_mapping = form_mapping self.base_form = base_form + self.expected_form = expected_form + self.form_mapping_info = form_mapping_info def __call__(self, blockwise_args) -> AwkArray: ( @@ -1104,8 +1112,9 @@ def project_keys(self: T, keys: frozenset[str]) -> T: self.real_options, keys, self.interp_options, - self.form_mapping, self.base_form, + self.expected_form, + self.form_mapping_info, ) @@ -1289,13 +1298,22 @@ def real_filter_branch(branch): divisions.append(0) partition_args.append((0, 0, 0)) + if form_mapping is None: + expected_form = dask_awkward.lib.utils.form_with_unique_keys( + base_form, "" + ) + form_mapping_info = TrivialFormMappingInfo(expected_form) + else: + expected_form, form_mapping_info = form_mapping(base_form) + return dask_awkward.from_map( _UprootRead( ttrees, common_keys, interp_options, - form_mapping=TrivialFormMapping() if form_mapping is None else form_mapping, base_form=base_form, + expected_form=expected_form, + form_mapping_info=form_mapping_info, ), partition_args, divisions=tuple(divisions), @@ -1370,6 +1388,14 @@ def _get_dak_array_delay_open( ) ) + if form_mapping is None: + expected_form = dask_awkward.lib.utils.form_with_unique_keys( + base_form, "" + ) + form_mapping_info = TrivialFormMappingInfo(expected_form) + else: + expected_form, form_mapping_info = form_mapping(base_form) + return dask_awkward.from_map( _UprootOpenAndRead( custom_classes, @@ -1377,8 +1403,9 @@ def _get_dak_array_delay_open( real_options, common_keys, interp_options, - form_mapping=TrivialFormMapping() if form_mapping is None else form_mapping, base_form=base_form, + expected_form=expected_form, + form_mapping_info=form_mapping_info, ), partition_args, divisions=None if divisions is None else tuple(divisions),