Skip to content

Commit

Permalink
Merge pull request #431 from European-XFEL/feat/sd-train-id-coords
Browse files Browse the repository at this point in the history
Add SourceData.train_id_coordinates()
  • Loading branch information
philsmt authored Jul 26, 2023
2 parents e6a20d9 + 1863baf commit 1f73126
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 12 deletions.
57 changes: 45 additions & 12 deletions extra_data/sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ def _get_first_source_file(self):

return FileAccess(sample_path)

def _get_index_group_sample(self, index_group):
    """Pick one key of this source that belongs to *index_group*.

    For a control source with an empty/omitted index group, any key
    will do, so ``one_key()`` is used directly. Otherwise the keys are
    scanned in order and the first one whose ``index_group`` equals
    the requested one is returned.

    Raises:
        ValueError: if no key of this source is in *index_group*.
    """
    if self.is_control and not index_group:
        # Shortcut for CONTROL data.
        return self.one_key()

    # Lazy scan, so only keys up to the first match are inspected.
    not_found = object()
    candidates = (
        key for key in self.keys()
        if self[key].index_group == index_group
    )
    sample_key = next(candidates, not_found)

    if sample_key is not_found:
        raise ValueError(f'{index_group} not an index group of this source')

    return sample_key

@property
def storage_class(self):
if self._first_source_file is ...:
Expand Down Expand Up @@ -312,12 +323,12 @@ def data_counts(self, labelled=True, index_group=None):
"""

if index_group is None:
sample_keys = dict(zip(
[self[key].index_group for key in self.keys()], self.keys()))

# Collect data counts for a sample key per index group.
data_counts = {
prefix: self[key].data_counts(labelled=labelled)
for prefix, key in sample_keys.items()
index_group: self[
self._get_index_group_sample(index_group)
].data_counts(labelled=labelled)
for index_group in self.index_groups
}

if labelled:
Expand All @@ -327,14 +338,36 @@ def data_counts(self, labelled=True, index_group=None):
return np.stack(list(data_counts.values())).max(axis=0)

else:
for key in self.keys():
if self[key].index_group == index_group:
break
else:
raise ValueError(f'{index_group} not an index group of this '
f'source')
return self[self._get_index_group_sample(index_group)] \
.data_counts(labelled=labelled)

return self[key].data_counts(labelled=labelled)
def train_id_coordinates(self, index_group=None):
    """Make an array of train IDs to use alongside data from this source.

    If *index_group* is omitted, the shared train ID coordinates
    across all index groups are returned if there is one. Unlike for
    ``.data_counts()``, an exception is raised if the train ID
    coordinates (and thus data counts) differ among the index groups.

    The result is obtained by calling ``.train_id_coordinates()`` on a
    sample key of the chosen index group.

    Raises:
        ValueError: if *index_group* is not an index group of this
            source, if the source has no index groups at all, or if
            *index_group* is omitted and the index groups have
            differing data counts.
    """

    if index_group is None:
        # Take a private snapshot instead of calling ``.pop()`` on the
        # object returned by ``self.index_groups``: popping would remove
        # a group from it, which corrupts state if that object is ever
        # cached or shared. It also avoids re-evaluating the property.
        groups = list(self.index_groups)

        if not groups:
            raise ValueError('source has no index groups')

        if len(groups) > 1:
            # Verify that a common train ID coordinate exists for
            # multiple index groups. The reads necessary for this
            # operation are identical to those for the train ID
            # coordinates themselves.
            counts_per_group = np.stack([
                self.data_counts(labelled=False, index_group=group)
                for group in groups])

            if (counts_per_group != counts_per_group[0]).any():
                raise ValueError('source has index groups with differing '
                                 'data counts')

        # All groups agree at this point, so any of them can serve as
        # the representative.
        index_group = groups[0]

    sample_key = self._get_index_group_sample(index_group)
    return self[sample_key].train_id_coordinates()

def run_metadata(self) -> Dict:
"""Get a dictionary of metadata about the run
Expand Down
46 changes: 46 additions & 0 deletions extra_data/tests/test_sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,49 @@ def test_drop_empty_trains(mock_reduced_spb_proc_run):

with pytest.raises(ValueError):
am0.drop_empty_trains(index_group='preamble')


def test_train_id_coordinates(mock_reduced_spb_proc_run):
    """Exercise source-level train_id_coordinates() on three sources.

    For each source, the source-level result is compared against the
    result for an explicit index group and against the result obtained
    from one of its keys; unknown index groups must raise ValueError.
    """
    run = RunDirectory(mock_reduced_spb_proc_run)

    # Control source: omitting the index group equals passing ''.
    xgm = run['SPB_XTD9_XGM/DOOCS/MAIN']

    implicit_ids = xgm.train_id_coordinates()
    explicit_ids = xgm.train_id_coordinates('')
    np.testing.assert_equal(implicit_ids, explicit_ids)

    source_ids = xgm.train_id_coordinates()
    key_ids = xgm['pulseEnergy.conversion'].train_id_coordinates()
    np.testing.assert_equal(source_ids, key_ids)

    with pytest.raises(ValueError):
        xgm.train_id_coordinates('data')

    # Instrument source: omitting the index group equals passing 'data'.
    camera = run['SPB_IRU_CAM/CAM/SIDEMIC:daqOutput']

    implicit_ids = camera.train_id_coordinates()
    explicit_ids = camera.train_id_coordinates('data')
    np.testing.assert_equal(implicit_ids, explicit_ids)

    source_ids = camera.train_id_coordinates()
    key_ids = camera['data.image.pixels'].train_id_coordinates()
    np.testing.assert_equal(source_ids, key_ids)

    with pytest.raises(ValueError):
        camera.train_id_coordinates('image')

    # XTDF source: each index group must agree with one of its own keys.
    am0 = run['SPB_DET_AGIPD1M-1/DET/0CH0:xtdf']

    group_samples = [
        ('header', 'header.pulseCount'),
        ('image', 'image.data'),
    ]
    for group, sample_key in group_samples:
        group_ids = am0.train_id_coordinates(group)
        key_ids = am0[sample_key].train_id_coordinates()
        np.testing.assert_equal(group_ids, key_ids)

    # Should fail due to multiple index groups with differing counts.
    with pytest.raises(ValueError):
        am0.train_id_coordinates()

0 comments on commit 1f73126

Please sign in to comment.