Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: RNTuple bug-fixing offset array concatenation, adding filter_name #1285

Merged
51 changes: 42 additions & 9 deletions src/uproot/models/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import struct
from collections import defaultdict
from itertools import accumulate

import numpy

Expand Down Expand Up @@ -60,8 +61,20 @@ def _keys(self):
keys.append(fr.field_name)
return keys

def keys(self):
return self._keys
def keys(
    self,
    *,
    filter_name=None,
    filter_typename=None,
    recursive=False,
    full_paths=True,
    # TODO: some arguments might be missing when compared with TTree. Solve when blocker is present in dask/coffea.
):
    """
    Return the field names of this RNTuple, optionally filtered by name.

    Args:
        filter_name (None, str, iterable of str, or callable): Selects which
            keys to return. A falsy value (``None``, ``""``, ``[]``) selects
            every key. A string is treated as a single name or glob pattern
            (``"*"``, ``"jet_*"``). An iterable selects keys that equal, or
            glob-match, any of its entries. A callable is used as a predicate
            ``filter_name(key) -> bool``.
        filter_typename: Accepted for signature compatibility; not used by
            this implementation.
        recursive: Accepted for signature compatibility; not used by this
            implementation.
        full_paths: Accepted for signature compatibility; not used by this
            implementation.

    Returns:
        list: The selected keys, in the order given by ``self._keys``.
    """
    from fnmatch import fnmatchcase

    if not filter_name:
        return self._keys

    if callable(filter_name):
        return [key for key in self._keys if filter_name(key)]

    # A bare string must be handled as ONE pattern: iterating it (or using
    # ``key in filter_name``) would silently do substring matching, so that
    # e.g. "jet_pt_extra" would select the unrelated key "pt".
    if isinstance(filter_name, str):
        patterns = [filter_name]
    else:
        patterns = list(filter_name)

    def is_selected(key):
        # Exact membership first, so names containing glob metacharacters
        # still match themselves literally.
        if key in patterns:
            return True
        return any(
            fnmatchcase(key, pattern)
            for pattern in patterns
            if isinstance(pattern, str)
        )

    return [key for key in self._keys if is_selected(key)]

def read_members(self, chunk, cursor, context, file):
if uproot._awkwardforth.get_forth_obj(context) is not None:
Expand Down Expand Up @@ -480,10 +493,28 @@ def read_pagedesc(self, destination, desc, dtype_str, dtype, nbits, split):
# needed to chop off extra bits in case we used `unpackbits`
destination[:] = content[:num_elements]

def read_col_pages(self, ncol, cluster_range, pad_missing_ele=False):
res = numpy.concatenate(
[self.read_col_page(ncol, i) for i in cluster_range], axis=0
)
def read_col_pages(self, ncol, cluster_range, dtype_byte, pad_missing_ele=False):
arrays = [self.read_col_page(ncol, i) for i in cluster_range]

# Check if column stores offset values for jagged arrays (splitindex64 or splitindex32) (applies to cardinality cols too):
if (
dtype_byte == uproot.const.rntuple_col_type_to_num_dict["splitindex64"]
or dtype_byte == uproot.const.rntuple_col_type_to_num_dict["splitindex32"]
):
# Extract the last offset values:
last_elements = [
arr[-1] for arr in arrays[:-1]
] # First value always zero, therefore skip first arr.
# Compute cumulative sum using itertools.accumulate:
last_offsets = list(accumulate(last_elements))
# Add the offsets to each array
for i in range(1, len(arrays)):
arrays[i] += last_offsets[i - 1]
# Remove the first element from every sub-array except for the first one:
arrays = [arrays[0]] + [arr[1:] for arr in arrays[1:]]

res = numpy.concatenate(arrays, axis=0)

if pad_missing_ele:
first_ele_index = self.column_records[ncol].first_ele_index
res = numpy.pad(res, (first_ele_index, 0))
Expand Down Expand Up @@ -530,8 +561,8 @@ def read_col_page(self, ncol, cluster_i):

def arrays(
self,
filter_names="*",
filter_typenames=None,
filter_name="*",
filter_typename=None,
entry_start=0,
entry_stop=None,
decompression_executor=None,
Expand All @@ -553,7 +584,7 @@ def arrays(
)

form = self.to_akform().select_columns(
filter_names, prune_unions_and_records=False
filter_name, prune_unions_and_records=False
)
# only read columns mentioned in the awkward form
target_cols = []
Expand All @@ -563,9 +594,11 @@ def arrays(
if "column" in key and "union" not in key:
key_nr = int(key.split("-")[1])
dtype_byte = self.column_records[key_nr].type

content = self.read_col_pages(
key_nr,
range(start_cluster_idx, stop_cluster_idx),
dtype_byte=dtype_byte,
pad_missing_ele=True,
)
if "cardinality" in key:
Expand Down
Loading