Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support RNTuples with multiple cluster groups #1359

Merged
merged 4 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/uproot/models/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ def _num_entries_for(in_ntuple, target_num_bytes, filter_name):
if "column" in key and "union" not in key:
key_nr = int(key.split("-")[1])
for cluster in range(start_cluster_idx, stop_cluster_idx):
pages = in_ntuple.ntuple.page_list_envelopes.pagelinklist[cluster][
key_nr
].pages
pages = in_ntuple.ntuple.page_link_list[cluster][key_nr].pages
total_bytes += sum(page.locator.num_bytes for page in pages)

total_entries = entry_stop
Expand Down Expand Up @@ -288,6 +286,8 @@ def read_members(self, chunk, cursor, context, file):
self._length = None

self._page_list_envelopes = []
self._cluster_summaries = None
self._page_link_list = None

self.ntuple = self

Expand Down Expand Up @@ -431,7 +431,19 @@ def footer(self):

@property
def cluster_summaries(self):
return self.page_list_envelopes.cluster_summaries
if self._cluster_summaries is None:
self._cluster_summaries = []
for pl in self.page_list_envelopes:
self._cluster_summaries.extend(pl.cluster_summaries)
return self._cluster_summaries

@property
def page_link_list(self):
if self._page_link_list is None:
self._page_link_list = []
for pl in self.page_list_envelopes:
self._page_link_list.extend(pl.pagelinklist)
return self._page_link_list

@property
def num_entries(self):
Expand Down Expand Up @@ -512,8 +524,8 @@ def page_list_envelopes(self):
decomp_chunk, cursor = self.read_locator(
loc, link.env_uncomp_size, context
)
self._page_list_envelopes = PageLink().read(
decomp_chunk, cursor, context
self._page_list_envelopes.append(
PageLink().read(decomp_chunk, cursor, context)
)

return self._page_list_envelopes
Expand Down Expand Up @@ -785,7 +797,7 @@ def read_col_pages(
return res

def read_col_page(self, ncol, cluster_i):
linklist = self.page_list_envelopes.pagelinklist[cluster_i]
linklist = self.ntuple.page_link_list[cluster_i]
# Check if the column is suppressed and pick the non-suppressed one if so
if ncol < len(linklist) and linklist[ncol].suppressed:
rel_crs = self._column_records_dict[self.column_records[ncol].field_id]
Expand Down Expand Up @@ -1203,9 +1215,6 @@ def read(self, chunk, cursor, context):
class FooterReader:
def __init__(self):
self.extension_header_links = RNTupleSchemaExtension()
self.cluster_summary_frames = ListFrameReader(
RecordFrameReader(ClusterSummaryReader())
)
self.cluster_group_record_frames = ListFrameReader(
RecordFrameReader(ClusterGroupRecordReader())
)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_1159_rntuple_cluster_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE

import pytest
import skhep_testdata

import uproot


def test_multiple_cluster_groups():
filename = skhep_testdata.data_path(
"test_multiple_cluster_groups_rntuple_v1-0-0-0.root"
)
with uproot.open(filename) as f:
obj = f["ntuple"]

assert len(obj.footer.cluster_group_records) == 3

assert obj.footer.cluster_group_records[0].num_clusters == 5
assert obj.footer.cluster_group_records[1].num_clusters == 4
assert obj.footer.cluster_group_records[2].num_clusters == 3

assert obj.num_entries == 1000

arrays = obj.arrays()

assert arrays.one.tolist() == list(range(1000))
assert arrays.int_vector.tolist() == [[i, i + 1] for i in range(1000)]
4 changes: 1 addition & 3 deletions tests/test_1191_rntuple_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ def test_schema_extension():
with uproot.open(filename) as f:
obj = f["ntuple"]

assert len(obj.page_list_envelopes.pagelinklist[0]) < len(
obj.page_list_envelopes.pagelinklist[1]
)
assert len(obj.page_link_list[0]) < len(obj.page_link_list[1])

assert len(obj.column_records) > len(obj.header.column_records)
assert len(obj.column_records) == 4
Expand Down
8 changes: 4 additions & 4 deletions tests/test_1347_rntuple_floats_suppressed_cols.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,11 @@ def test_multiple_representations():
with uproot.open(filename) as f:
obj = f["ntuple"]

assert len(obj.page_list_envelopes.pagelinklist) == 3
assert len(obj.page_link_list) == 3
# The zeroth representation is active in clusters 0 and 2, but not in cluster 1
assert not obj.page_list_envelopes.pagelinklist[0][0].suppressed
assert obj.page_list_envelopes.pagelinklist[1][0].suppressed
assert not obj.page_list_envelopes.pagelinklist[2][0].suppressed
assert not obj.page_link_list[0][0].suppressed
assert obj.page_link_list[1][0].suppressed
assert not obj.page_link_list[2][0].suppressed

arrays = obj.arrays()

Expand Down
Loading