From 5d602141c5bcd33f82bfa6c535d0687f7f57a6a2 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Tue, 16 Aug 2022 18:23:03 +0100 Subject: [PATCH 01/16] Start creating KeyData interface for multi-module detector data --- extra_data/components.py | 496 +++++++++++++++++++++++---------------- 1 file changed, 297 insertions(+), 199 deletions(-) diff --git a/extra_data/components.py b/extra_data/components.py index de37663f..64e2a766 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -90,6 +90,32 @@ def _check_pulse_selection(pulses): return type(pulses)(val) +def _select_pulse_ids(pulses, data_pulse_ids): + """Select pulses by ID across a chunk of trains + + Returns a boolean array of which entries in data_pulse_ids match. + """ + if isinstance(pulses.value, slice): + s = pulses.value + desired = np.arange(s.start, s.stop, step=s.step, dtype=np.uint64) + else: + desired = pulses.value + + return np.isin(data_pulse_ids, desired) + + +def _out_array(shape, dtype, fill_value=None): + if fill_value is None: + fill_value = np.nan if dtype.kind == 'f' else 0 + fill_value = dtype.type(fill_value) + + # Zeroed memory can be allocated faster than explicitly writing zeros + if fill_value == 0: + return np.zeros(shape, dtype=dtype) + else: + return np.full(shape, fill_value, dtype=dtype) + + class MultimodDetectorBase: """Base class for detectors made of several modules as separate data sources """ @@ -98,6 +124,7 @@ class MultimodDetectorBase: # Override in subclass _main_data_key = '' # Key to use for checking data counts match _frames_per_entry = 1 # Override if separate pulse dimension in files + _modnos_start_at = 0 # Override if module numbers start at 1 (JUNGFRAU) module_shape = (0, 0) n_modules = 0 @@ -147,6 +174,9 @@ def __init__(self, data: DataCollection, detector_name=None, modules=None, # If we add extra instance attributes, check whether they should be # updated in .select_trains() below. + def __getitem__(self, item): + return MultimodKeyData(self, item) + @classmethod def _find_detector_name(cls, data): detector_names = set() @@ -382,17 +412,6 @@ def _concat(arrays, index, fill_value, astype): fill_value=fill_value ) - @staticmethod - def _out_array(shape, dtype, fill_value=None): - if fill_value is None: - fill_value = np.nan if dtype.kind == 'f' else 0 - fill_value = dtype.type(fill_value) - - if fill_value == 0: - return np.zeros(shape, dtype=dtype) - else: - return np.full(shape, fill_value, dtype=dtype) - def get_array(self, key, *, fill_value=None, roi=(), astype=None): """Get a labelled array of detector data @@ -411,36 +430,7 @@ def get_array(self, key, *, fill_value=None, roi=(), astype=None): Data type of the output array. 
If None (default) the dtype matches the input array dtype """ - train_ids = np.asarray(self.data.train_ids) - - eg_src = min(self.source_to_modno) - eg_keydata = self.data[eg_src, key] - - # Find the shape of 1 frame for 1 module with the ROI applied - out_shape = ((len(self.modno_to_source), len(train_ids)) - + roi_shape(eg_keydata.entry_shape, roi)) - - dtype = eg_keydata.dtype if astype is None else np.dtype(astype) - out = self._out_array(out_shape, dtype, fill_value=fill_value) - - - modnos = [] - for mod_ix, (modno, source) in enumerate(sorted(self.modno_to_source.items())): - for chunk in self.data._find_data_chunks(source, key): - for tgt_slice, chunk_slice in self._split_align_chunk(chunk, train_ids): - chunk.dataset.read_direct( - out[mod_ix, tgt_slice], source_sel=(chunk_slice,) + roi - ) - - modnos.append(modno) - - # Dimension labels - dims = ['module', 'trainId'] + ['dim_%d' % i for i in range(out.ndim - 2)] - - # Train ID index - coords = {'module': modnos, 'trainId': train_ids} - - return xarray.DataArray(out, dims=dims, coords=coords) + return self[key].xarray(fill_value=fill_value, roi=roi, astype=astype) def get_dask_array(self, key, fill_value=None, astype=None): @@ -499,19 +489,12 @@ def __init__(self, data: DataCollection, detector_name=None, modules=None, *, min_modules=1): super().__init__(data, detector_name, modules, min_modules=min_modules) - @staticmethod - def _select_pulse_ids(pulses, data_pulse_ids): - """Select pulses by ID across a chunk of trains - - Returns an array or slice of the indexes to include. - """ - if isinstance(pulses.value, slice): - s = pulses.value - desired = np.arange(s.start, s.stop, step=s.step, dtype=np.uint64) - else: - desired = pulses.value + def __getitem__(self, item): + if item.startswith('image.'): + return XtdfImageMultimodKeyData(self, item) + return super().__getitem__(item) - return np.isin(data_pulse_ids, desired) + # Several methods below are overridden in LPD1M for parallel gain mode @staticmethod def _select_pulse_indices(pulses, counts): @@ -554,52 +537,6 @@ def _make_image_index(self, tids, inner_ids, inner_name='pulse'): [tids, inner_ids], names=['train', inner_name] ) - @staticmethod - def _guess_axes(data, train_pulse_ids, unstack_pulses, modnos=None): - shape = data.shape - if modnos is not None: - shape = shape[1:] - ndim = len(shape) - - # Raw files have a spurious extra dimension - if ndim >= 2 and shape[1] == 1: - if modnos is None: - data = data[:, 0] - else: - data = data[:, :, 0] - ndim -= 1 - - # TODO: this assumes we can tell what the axes are just from the - # number of dimensions. Works for the data we've seen, but we - # should look for a more reliable way. - if ndim == 4: - # image.data in raw data - dims = ['train_pulse', 'data_gain', 'slow_scan', 'fast_scan'] - elif ndim == 3: - # image.data, image.gain, image.mask in calibrated data - dims = ['train_pulse', 'slow_scan', 'fast_scan'] - else: - # Everything else seems to be 1D - dims = ['train_pulse'] - - coords = {'train_pulse': train_pulse_ids} - if modnos is not None: - dims = ['module'] + dims - coords['module'] = modnos - - arr = xarray.DataArray(data, coords=coords, dims=dims) - - if unstack_pulses: - # Separate train & pulse dimensions, and arrange dimensions - # so that the data is contiguous in memory. 
- if modnos is None: - dim_order = train_pulse_ids.names + dims[1:] - else: - dim_order = ['module'] + train_pulse_ids.names + dims[2:] - return arr.unstack('train_pulse').transpose(*dim_order) - else: - return arr - def _read_inner_ids(self, field='pulseId'): """Read pulse/cell IDs into a 2D array (frames, modules) @@ -654,100 +591,6 @@ def _collect_inner_ids(self, field='pulseId'): # pulse ID for each frame. return inner_ids_min - def _read_chunk(self, chunk: DataChunk, sel_frames, mod_out, roi): - """Read per-pulse data from file into an output array (of 1 module)""" - for tgt_slice, chunk_slice in self._split_align_chunk( - chunk, self.train_ids_perframe - ): - inc_pulses_chunk = sel_frames[tgt_slice] - if inc_pulses_chunk.sum() == 0: # No data from this chunk selected - continue - elif inc_pulses_chunk.all(): # All pulses in chunk - chunk.dataset.read_direct( - mod_out[tgt_slice], source_sel=(chunk_slice,) + roi - ) - continue - - # Read a subset of pulses from the chunk: - - # Reading a non-contiguous selection in HDF5 seems to be slow: - # https://forum.hdfgroup.org/t/performance-reading-data-with-non-contiguous-selection/8979 - # Except it's fast if you read the data to a matching selection in - # memory (one weird trick). - # So as a workaround, this allocates a temporary array of the same - # shape as the dataset, reads into it, and then copies the selected - # data to the output array. The extra memory copy is not optimal, - # but it's better than the HDF5 performance issue, at least in some - # realistic cases. - # N.B. tmp should only use memory for the data it contains - - # zeros() uses calloc, so the OS can do virtual memory tricks. - # Don't change this to zeros_like() ! - tmp = np.zeros(chunk.dataset.shape, chunk.dataset.dtype) - pulse_sel = np.nonzero(inc_pulses_chunk)[0] + chunk_slice.start - sel_region = (pulse_sel,) + roi - chunk.dataset.read_direct( - tmp, source_sel=sel_region, dest_sel=sel_region, - ) - # Where does this data go in the target array? 
- tgt_start_ix = sel_frames[:tgt_slice.start].sum() - tgt_pulse_sel = slice( - tgt_start_ix, tgt_start_ix + inc_pulses_chunk.sum() - ) - # Copy data from temp array to output array - tmp_frames_mask = np.zeros(len(tmp), dtype=np.bool_) - tmp_frames_mask[pulse_sel] = True - np.compress( - tmp_frames_mask, tmp[np.index_exp[:] + roi], - axis=0, out=mod_out[tgt_pulse_sel] - ) - - def _get_pulse_data(self, key, pulses, unstack_pulses=True, - fill_value=None, subtrain_index='pulseId', roi=(), - astype=None): - """Get a labelled array of per-pulse data (image.*) for xtdf detector""" - pulses = _check_pulse_selection(pulses) - - if isinstance(pulses, by_index): - sel_frames = self._select_pulse_indices(pulses, self.frame_counts) - pulse_ids = None - else: # by_id - pulse_ids = self._collect_inner_ids('pulseId') - sel_frames = self._select_pulse_ids(pulses, pulse_ids) - - nframes_sel = sel_frames.sum() - - eg_src = min(self.source_to_modno) - eg_keydata = self.data[eg_src, key] - - if eg_keydata.ndim >= 2 and eg_keydata.entry_shape[0] == 1: - # Ensure ROI applies to pixel dimensions, not the extra - # dim in raw data (except AGIPD, where it is data/gain) - roi = np.index_exp[:] + roi - - _roi_shape = roi_shape(eg_keydata.entry_shape, roi) - out_shape = (len(self.modno_to_source), nframes_sel) + _roi_shape - - dtype = eg_keydata.dtype if astype is None else np.dtype(astype) - out = self._out_array(out_shape, dtype, fill_value=fill_value) - - modnos = [] - for mod_ix, (modno, source) in enumerate(sorted(self.modno_to_source.items())): - for chunk in self.data._find_data_chunks(source, key): - self._read_chunk(chunk, sel_frames, out[mod_ix], roi) - - modnos.append(modno) - - if (subtrain_index == 'pulseId') and (pulse_ids is not None): - inner_ids = pulse_ids - else: - inner_ids = self._collect_inner_ids(subtrain_index) - - index = self._make_image_index( - self.train_ids_perframe, inner_ids, subtrain_index[:-2] - )[sel_frames] - - return self._guess_axes(out, index, unstack_pulses, modnos=modnos) - def get_array(self, key, pulses=np.s_[:], unstack_pulses=True, *, fill_value=None, subtrain_index='pulseId', roi=(), astype=None): @@ -789,9 +632,9 @@ def get_array(self, key, pulses=np.s_[:], unstack_pulses=True, *, roi = (roi,) if key.startswith('image.'): - return self._get_pulse_data( - key, pulses, unstack_pulses, fill_value=fill_value, - subtrain_index=subtrain_index, roi=roi, astype=astype, + return self[key].select_pulses(pulses).xarray( + fill_value=fill_value, roi=roi, subtrain_index=subtrain_index, + astype=astype, unstack_pulses=unstack_pulses, ) else: return super().get_array( @@ -926,6 +769,248 @@ def zip_trains_pulses(trains, pulses): return res + +class MultimodKeyData: + def __init__(self, det: MultimodDetectorBase, key): + self.det = det + self.key = key + self.modno_to_keydata = { + m: det.data[s, key] for (m, s) in det.modno_to_source.items() + } + + @property + def train_ids(self): + return self.det.train_ids + + def train_id_coordinates(self): + return self.det.train_ids + + @property + def modules(self): + return sorted(self.modno_to_keydata) + + @property + def _eg_keydata(self): + return self.modno_to_keydata[min(self.modno_to_keydata)] + + @property + def ndim(self): + return self._eg_keydata.ndim + 1 + + @property + def shape(self): + return ((len(self.modno_to_keydata), len(self.train_id_coordinates())) + + self._eg_keydata.entry_shape) + + @property + def dimensions(self): + return ['module', 'trainId'] + ['dim_%d' % i for i in range(self.ndim - 2)] + + def ndarray(self, *, 
fill_value=None, out=None, roi=(), astype=None, module_gaps=False): + """Get data as a plain NumPy array with no labels""" + train_ids = np.asarray(self.det.train_ids) + + module_dim = self.det.n_modules if module_gaps else len(self.modno_to_keydata) + + out_shape = ((module_dim, len(train_ids)) + # Shape of 1 frame for 1 module with the ROI applied: + + roi_shape(self._eg_keydata.entry_shape, roi)) + + if out is None: + dtype = self._eg_keydata.dtype if astype is None else np.dtype(astype) + out = _out_array(out_shape, dtype, fill_value=fill_value) + elif out.shape != out_shape: + raise ValueError(f'requires output array of shape {out_shape}') + + for i, (modno, kd) in enumerate(sorted(self.modno_to_keydata.items())): + mod_ix = (modno - self.det._modnos_start_at) if module_gaps else i + for chunk in kd._data_chunks: + for tgt_slice, chunk_slice in self.det._split_align_chunk(chunk, train_ids): + chunk.dataset.read_direct( + out[mod_ix, tgt_slice], source_sel=(chunk_slice,) + roi + ) + return out + + def xarray(self, *, fill_value=None, roi=(), astype=None): + from xarray import DataArray + arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype) + + # Train ID index + coords = {'module': self.modules, 'trainId': self.train_id_coordinates()} + + return DataArray(arr, dims=self.dimensions, coords=coords) + + +class XtdfImageMultimodKeyData(MultimodKeyData): + _sel_frames_cached = None + + def __init__(self, det: XtdfDetectorBase, key, pulse_sel=np.s_[0:MAX_PULSES:1]): + super().__init__(det, key) + self.det = det # Makes PyCharm happy that det is XtdfDetectorBase + self._pulse_sel = pulse_sel + entry_shape = self._eg_keydata.entry_shape + self._extraneous_dim = (len(entry_shape) >= 1) and (entry_shape[0] == 1) + + @property + def ndim(self): + return super().ndim - (1 if self._extraneous_dim else 0) + + def _all_pulses(self): + psv = self._pulse_sel.value + return isinstance(psv, slice) and psv == slice(0, MAX_PULSES, 1) + + def _shape(self, module_gaps=False, roi=()): + module_dim = self.det.n_modules if module_gaps else len(self.modno_to_keydata) + nframes_sel = len(self.train_id_coordinates()) + + entry_shape = self._eg_keydata.entry_shape + if self._extraneous_dim: + entry_shape = entry_shape[1:] + + return (module_dim, nframes_sel) + roi_shape(entry_shape, roi) + + @property + def shape(self): + return self._shape() + + def train_id_coordinates(self): + # XTDF 'image' group can have >1 entry per train + a = self.det.train_ids_perframe + # Only allocate sel_frames array if we need it: + if not self._all_pulses(): + a = a[self._sel_frames] + return a + + @property + def dimensions(self): + ndim_inner = self.ndim - 2 + # TODO: this assumes we can tell what the axes are just from the + # number of dimensions. Works for the data we've seen, but we + # should look for a more reliable way. 
+ if ndim_inner == 3: + # image.data in raw data + entry_dims = ['data_gain', 'slow_scan', 'fast_scan'] + elif ndim_inner == 2: + # image.data, image.gain, image.mask in calibrated data + entry_dims = ['slow_scan', 'fast_scan'] + else: + # Everything else seems to be 1D, but just in case + entry_dims = [f'dim_{i}' for i in range(ndim_inner)] + return ['module', 'train_pulse'] + entry_dims + + def select_pulses(self, pulses): + pulses = _check_pulse_selection(pulses) + + return XtdfImageMultimodKeyData(self.det, self.key, pulses) + + @property + def _sel_frames(self): + if self._sel_frames_cached is None: + p = self._pulse_sel + if isinstance(p, by_index): + if self._all_pulses(): + s = np.ones(len(self.det.train_ids_perframe), np.bool_) + else: + s = self.det._select_pulse_indices(p, self.det.frame_counts) + else: # by_id + pulse_ids = self.det._collect_inner_ids('pulseId') + s = _select_pulse_ids(p, pulse_ids) + self._sel_frames_cached = s + return self._sel_frames_cached + + def _read_chunk(self, chunk: DataChunk, mod_out, roi): + """Read per-pulse data from file into an output array (of 1 module)""" + for tgt_slice, chunk_slice in self.det._split_align_chunk( + chunk, self.det.train_ids_perframe + ): + inc_pulses_chunk = self._sel_frames[tgt_slice] + if inc_pulses_chunk.sum() == 0: # No data from this chunk selected + continue + elif inc_pulses_chunk.all(): # All pulses in chunk + chunk.dataset.read_direct( + mod_out[tgt_slice], source_sel=(chunk_slice,) + roi + ) + continue + + # Read a subset of pulses from the chunk: + + # Reading a non-contiguous selection in HDF5 seems to be slow: + # https://forum.hdfgroup.org/t/performance-reading-data-with-non-contiguous-selection/8979 + # Except it's fast if you read the data to a matching selection in + # memory (one weird trick). + # So as a workaround, this allocates a temporary array of the same + # shape as the dataset, reads into it, and then copies the selected + # data to the output array. The extra memory copy is not optimal, + # but it's better than the HDF5 performance issue, at least in some + # realistic cases. + # N.B. tmp should only use memory for the data it contains - + # zeros() uses calloc, so the OS can do virtual memory tricks. + # Don't change this to zeros_like() ! + tmp = np.zeros(chunk.dataset.shape, chunk.dataset.dtype) + pulse_sel = np.nonzero(inc_pulses_chunk)[0] + chunk_slice.start + sel_region = (pulse_sel,) + roi + chunk.dataset.read_direct( + tmp, source_sel=sel_region, dest_sel=sel_region, + ) + # Where does this data go in the target array? 
+ tgt_start_ix = self._sel_frames[:tgt_slice.start].sum() + tgt_pulse_sel = slice( + tgt_start_ix, tgt_start_ix + inc_pulses_chunk.sum() + ) + # Copy data from temp array to output array + tmp_frames_mask = np.zeros(len(tmp), dtype=np.bool_) + tmp_frames_mask[pulse_sel] = True + np.compress( + tmp_frames_mask, tmp[np.index_exp[:] + roi], + axis=0, out=mod_out[tgt_pulse_sel] + ) + + def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None, module_gaps=False): + """Get an array of per-pulse data (image.*) for xtdf detector""" + out_shape = self._shape(module_gaps=module_gaps, roi=roi) + + if out is None: + dtype = self._eg_keydata.dtype if astype is None else np.dtype(astype) + out = _out_array(out_shape, dtype, fill_value=fill_value) + elif out.shape != out_shape: + raise ValueError(f'requires output array of shape {out_shape}') + + reading_view = out.view() + if self._extraneous_dim: + reading_view.shape = out.shape[:2] + (1,) + out.shape[2:] + # Ensure ROI applies to pixel dimensions, not the extra + # dim in raw data (except AGIPD, where it is data/gain) + roi = np.index_exp[:] + roi + print(f"{out_shape=}, {out.shape=}, {reading_view.shape=}, {self._extraneous_dim=}") + + for mod_ix, (modno, kd) in enumerate(sorted(self.modno_to_keydata.items())): + for chunk in kd._data_chunks: + self._read_chunk(chunk, reading_view[mod_ix], roi) + + return out + + def xarray(self, *, pulses=None, fill_value=None, roi=(), astype=None, + subtrain_index='pulseId', unstack_pulses=False): + arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype) + + inner_ids = self.det._collect_inner_ids(subtrain_index) + index = self.det._make_image_index( + self.det.train_ids_perframe, inner_ids, subtrain_index[:-2] + )[self._sel_frames] + + out = xarray.DataArray(arr, dims=self.dimensions, coords={ + 'train_pulse': index, 'module': self.modules, + }) + + if unstack_pulses: + # Separate train & pulse dimensions, and arrange dimensions + # so that the data is contiguous in memory. + dim_order = ['module'] + index.names + self.dimensions[2:] + return out.unstack('train_pulse').transpose(*dim_order) + + return out + + class FramesFileWriter(FileWriter): """Write selected detector frames in European XFEL HDF5 format""" def __init__(self, path, data, inc_tp_ids): @@ -1130,9 +1215,21 @@ def _get_pulse_data(self, source, key, tid): else: # ndarray data_positions = first + positions - return self.data._guess_axes( - ds[data_positions], train_pulse_ids, unstack_pulses=True - ) + data = ds[data_positions] + + # Raw files have a spurious extra dimension + if data.ndim >= 2 and data.shape[1] == 1: + data = data[:, 0] + + dims = self.data[key].dimensions[1:] # excluding 'module' dim + coords = {'train_pulse': train_pulse_ids} + + arr = xarray.DataArray(data, coords=coords, dims=dims) + + # Separate train & pulse dimensions, and arrange dimensions + # so that the data is contiguous in memory. 
+ dim_order = train_pulse_ids.names + dims[1:] + return arr.unstack('train_pulse').transpose(*dim_order) def _select_pulse_ids(self, pulse_ids): """Select pulses by ID @@ -1429,6 +1526,7 @@ class JUNGFRAU(MultimodDetectorBase): r'(MODULE_|RECEIVER-|JNGFR)(?P\d+)' ) _main_data_key = 'data.adc' + _modnos_start_at = 1 module_shape = (512, 1024) def __init__(self, data: DataCollection, detector_name=None, modules=None, From b416dfeaf4b16d46ac034c7148cd58685b8b6f23 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Wed, 17 Aug 2022 10:37:32 +0100 Subject: [PATCH 02/16] Add .dtype, .split_trains() and .select_trains() for multi-mod KeyData --- extra_data/components.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/extra_data/components.py b/extra_data/components.py index 64e2a766..1e3b30b2 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -806,6 +806,21 @@ def shape(self): def dimensions(self): return ['module', 'trainId'] + ['dim_%d' % i for i in range(self.ndim - 2)] + @property + def dtype(self): + return self._eg_keydata.dtype + + def _with_selected_det(self, det_selected): + # Overridden for XtdfImageMultimodKeyData to preserve pulse selection + return MultimodKeyData(det_selected, self.key) + + def select_trains(self, trains): + return self._with_selected_det(self.det.select_trains(trains)) + + def split_trains(self, parts=None, trains_per_part=None, frames_per_part=None): + for det_split in self.det.split_trains(parts, trains_per_part, frames_per_part): + yield self._with_selected_det(det_split) + def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None, module_gaps=False): """Get data as a plain NumPy array with no labels""" train_ids = np.asarray(self.det.train_ids) @@ -844,7 +859,7 @@ def xarray(self, *, fill_value=None, roi=(), astype=None): class XtdfImageMultimodKeyData(MultimodKeyData): _sel_frames_cached = None - def __init__(self, det: XtdfDetectorBase, key, pulse_sel=np.s_[0:MAX_PULSES:1]): + def __init__(self, det: XtdfDetectorBase, key, pulse_sel=by_index[0:MAX_PULSES:1]): super().__init__(det, key) self.det = det # Makes PyCharm happy that det is XtdfDetectorBase self._pulse_sel = pulse_sel @@ -898,6 +913,10 @@ def dimensions(self): entry_dims = [f'dim_{i}' for i in range(ndim_inner)] return ['module', 'train_pulse'] + entry_dims + # Used for .select_trains() and .split_trains() + def _with_selected_det(self, det_selected): + return XtdfImageMultimodKeyData(det_selected, self.key, self._pulse_sel) + def select_pulses(self, pulses): pulses = _check_pulse_selection(pulses) From 7603a2881fe75556ce91b2251c6c6ee72fce8869 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Wed, 17 Aug 2022 16:19:21 +0100 Subject: [PATCH 03/16] Add .dask_array to multi-mod KeyData interface --- extra_data/components.py | 122 +++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 57 deletions(-) diff --git a/extra_data/components.py b/extra_data/components.py index 1e3b30b2..24bc8623 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -399,19 +399,6 @@ def split_trains(self, parts=None, trains_per_part=None, frames_per_part=None): # There will always be at least the last train left to yield yield self.select_trains(np.s_[chunk_start:]) - @staticmethod - def _concat(arrays, index, fill_value, astype): - dtype = arrays[0].dtype if astype is None else np.dtype(astype) - if fill_value is None: - fill_value = np.nan if dtype.kind == 'f' else 0 - fill_value = dtype.type(fill_value) - - return 
xarray.concat( - [a.astype(dtype, copy=False) for a in arrays], - pd.Index(index, name='module'), - fill_value=fill_value - ) - def get_array(self, key, *, fill_value=None, roi=(), astype=None): """Get a labelled array of detector data @@ -447,14 +434,7 @@ def get_dask_array(self, key, fill_value=None, astype=None): Data type of the output array. If None (default) the dtype matches the input array dtype """ - arrays = [] - modnos = [] - for modno, source in sorted(self.modno_to_source.items()): - modnos.append(modno) - mod_arr = self.data.get_dask_array(source, key, labelled=True) - arrays.append(mod_arr) - - return self._concat(arrays, modnos, fill_value, astype) + return self[key].dask_array(labelled=True, fill_value=fill_value, astype=astype) def trains(self, require_all=True): """Iterate over trains for detector data. @@ -668,34 +648,19 @@ def get_dask_array(self, key, subtrain_index='pulseId', fill_value=None, """ if subtrain_index not in {'pulseId', 'cellId'}: raise ValueError("subtrain_index must be 'pulseId' or 'cellId'") - arrays = [] - modnos = [] - for modno, source in sorted(self.modno_to_source.items()): - modnos.append(modno) - mod_arr = self.data.get_dask_array(source, key, labelled=True) - - # At present, all the per-pulse data is stored in the 'image' key. - # If that changes, this check will need to change as well. - if key.startswith('image.'): - # Add pulse IDs to create multi-level index - inner_ix = self.data.get_array(source, 'image.' + subtrain_index) - # Raw files have a spurious extra dimension - if inner_ix.ndim >= 2 and inner_ix.shape[1] == 1: - inner_ix = inner_ix[:, 0] - - mod_arr = mod_arr.rename({'trainId': 'train_pulse'}) - - mod_arr.coords['train_pulse'] = self._make_image_index( - mod_arr.coords['train_pulse'].values, inner_ix.values, - inner_name=subtrain_index, - ).set_names('trainId', level=0) - # This uses 'trainId' where a concrete array from the same class - # uses 'train'. I didn't notice that inconsistency when I - # introduced it, and now code may be relying on each name. - - arrays.append(mod_arr) - - return self._concat(arrays, modnos, fill_value, astype) + if key.startswith('image.'): + arr = self[key].dask_array( + labelled=True, subtrain_index=subtrain_index, + fill_value=fill_value, astype=astype + ) + # Preserve the quirks of this method before refactoring + if self[key]._extraneous_dim: + arr = arr.expand_dims('tmp_name', axis=2) + renames = {'train': 'trainId', subtrain_index[:-2]: subtrain_index} + renames.update({name: f'dim_{i}' for i, name in enumerate(arr.dims[2:])}) + return arr.rename(renames) + else: + return super().get_dask_array(key, fill_value=fill_value, astype=astype) def trains(self, pulses=np.s_[:], require_all=True): """Iterate over trains for detector data. 
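
The refactored get_dask_array() above is now a thin wrapper over the new per-key
interface. A minimal sketch of the intended equivalence, assuming `det` is a
detector component opened on a run (e.g. a JUNGFRAU, whose main data key is
'data.adc'); the variable names are illustrative, not taken from the patch:

    legacy = det.get_dask_array('data.adc', astype='f4')
    modern = det['data.adc'].dask_array(labelled=True, astype='f4')
    # Both should yield the same labelled Dask array: det[key] returns the
    # MultimodKeyData object introduced in the first patch of this series.
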
@@ -850,11 +815,31 @@ def xarray(self, *, fill_value=None, roi=(), astype=None): from xarray import DataArray arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype) - # Train ID index coords = {'module': self.modules, 'trainId': self.train_id_coordinates()} - return DataArray(arr, dims=self.dimensions, coords=coords) + def dask_array(self, *, labelled=False, fill_value=None, astype=None): + from dask.delayed import delayed + from dask.array import concatenate, from_delayed + + entry_size = (self.dtype.itemsize * + len(self.modno_to_keydata) * np.product(self._eg_keydata.entry_shape) + ) + # Aim for 1GB chunks, with an arbitrary maximum of 256 trains + split = self.split_trains(frames_per_part=min(1024 ** 3 / entry_size, 256)) + + arr = concatenate([from_delayed( + delayed(c.ndarray)(fill_value=fill_value, astype=astype), + shape=c.shape, dtype=self.dtype + ) for c in split], axis=1) + + if labelled: + from xarray import DataArray + coords = {'module': self.modules, 'trainId': self.train_id_coordinates()} + return DataArray(arr, dims=self.dimensions, coords=coords) + + return arr + class XtdfImageMultimodKeyData(MultimodKeyData): _sel_frames_cached = None @@ -1008,27 +993,50 @@ def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None, module_gaps return out - def xarray(self, *, pulses=None, fill_value=None, roi=(), astype=None, - subtrain_index='pulseId', unstack_pulses=False): - arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype) - + def _wrap_xarray(self, arr, subtrain_index='pulseId'): + from xarray import DataArray inner_ids = self.det._collect_inner_ids(subtrain_index) index = self.det._make_image_index( self.det.train_ids_perframe, inner_ids, subtrain_index[:-2] )[self._sel_frames] - out = xarray.DataArray(arr, dims=self.dimensions, coords={ + return DataArray(arr, dims=self.dimensions, coords={ 'train_pulse': index, 'module': self.modules, }) + def xarray(self, *, pulses=None, fill_value=None, roi=(), astype=None, + subtrain_index='pulseId', unstack_pulses=False): + arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype) + out = self._wrap_xarray(arr, subtrain_index) + if unstack_pulses: # Separate train & pulse dimensions, and arrange dimensions # so that the data is contiguous in memory. 
- dim_order = ['module'] + index.names + self.dimensions[2:] + dim_order = ['module'] + out.coords['train_pulse'].names + self.dimensions[2:] return out.unstack('train_pulse').transpose(*dim_order) return out + def dask_array(self, *, labelled=True, subtrain_index='pulseId', + fill_value=None, astype=None): + from dask.delayed import delayed + from dask.array import concatenate, from_delayed + + entry_size = (self.dtype.itemsize * + len(self.modno_to_keydata) * np.product(self._eg_keydata.entry_shape) + ) + # Aim for 1GB chunks, with an arbitrary maximum of 1024 frames + split = self.split_trains(frames_per_part=min(1024 ** 3 / entry_size, 1024)) + + arr = concatenate([from_delayed( + delayed(c.ndarray)(fill_value=fill_value, astype=astype), + shape=c.shape, dtype=self.dtype + ) for c in split], axis=1) + + if labelled: + return self._wrap_xarray(arr, subtrain_index) + + return arr class FramesFileWriter(FileWriter): """Write selected detector frames in European XFEL HDF5 format""" From 1d76005a73284e925b859e930ce1c00581bb2f21 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Wed, 17 Aug 2022 16:35:21 +0100 Subject: [PATCH 04/16] Remove debugging print() --- extra_data/components.py | 1 - 1 file changed, 1 deletion(-) diff --git a/extra_data/components.py b/extra_data/components.py index 24bc8623..04ce8edb 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -985,7 +985,6 @@ def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None, module_gaps # Ensure ROI applies to pixel dimensions, not the extra # dim in raw data (except AGIPD, where it is data/gain) roi = np.index_exp[:] + roi - print(f"{out_shape=}, {out.shape=}, {reading_view.shape=}, {self._extraneous_dim=}") for mod_ix, (modno, kd) in enumerate(sorted(self.modno_to_keydata.items())): for chunk in kd._data_chunks: From 8e0fbd0d07b4f4fe4450dbeb3cbfa5e482b3ea8d Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Wed, 17 Aug 2022 16:44:49 +0100 Subject: [PATCH 05/16] Fix for unstacking train/pulse axis --- extra_data/components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extra_data/components.py b/extra_data/components.py index 04ce8edb..57abe379 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -1011,7 +1011,7 @@ def xarray(self, *, pulses=None, fill_value=None, roi=(), astype=None, if unstack_pulses: # Separate train & pulse dimensions, and arrange dimensions # so that the data is contiguous in memory. 
- dim_order = ['module'] + out.coords['train_pulse'].names + self.dimensions[2:] + dim_order = ['module'] + out.indexes['train_pulse'].names + self.dimensions[2:] return out.unstack('train_pulse').transpose(*dim_order) return out From 0735d0c90b27478914da06ffc433fd2876182f33 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Wed, 17 Aug 2022 17:06:09 +0100 Subject: [PATCH 06/16] Fix renaming index levels on labelled Dask array --- extra_data/components.py | 7 ++++--- extra_data/tests/conftest.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/extra_data/components.py b/extra_data/components.py index 57abe379..930987f3 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -656,9 +656,10 @@ def get_dask_array(self, key, subtrain_index='pulseId', fill_value=None, # Preserve the quirks of this method before refactoring if self[key]._extraneous_dim: arr = arr.expand_dims('tmp_name', axis=2) - renames = {'train': 'trainId', subtrain_index[:-2]: subtrain_index} - renames.update({name: f'dim_{i}' for i, name in enumerate(arr.dims[2:])}) - return arr.rename(renames) + arr.coords['train_pulse'] = arr.indexes['train_pulse'].rename( + ['trainId', subtrain_index] + ) + return arr.rename({name: f'dim_{i}' for i, name in enumerate(arr.dims[2:])}) else: return super().get_dask_array(key, fill_value=fill_value, astype=astype) diff --git a/extra_data/tests/conftest.py b/extra_data/tests/conftest.py index 8e678952..6ecc7451 100644 --- a/extra_data/tests/conftest.py +++ b/extra_data/tests/conftest.py @@ -8,7 +8,7 @@ from . import make_examples -@pytest.fixture(scope='session', params=['0.5', '1.0']) +@pytest.fixture(scope='session', params=['1.0']) def format_version(request): return request.param From 7d3c5051a6bc14a49bc3625d5cd1510cba88a28a Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Wed, 17 Aug 2022 18:04:00 +0100 Subject: [PATCH 07/16] fixup! Fix renaming index levels on labelled Dask array --- extra_data/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extra_data/tests/conftest.py b/extra_data/tests/conftest.py index 6ecc7451..8e678952 100644 --- a/extra_data/tests/conftest.py +++ b/extra_data/tests/conftest.py @@ -8,7 +8,7 @@ from . 
import make_examples
 
 
-@pytest.fixture(scope='session', params=['1.0'])
+@pytest.fixture(scope='session', params=['0.5', '1.0'])
 def format_version(request):
     return request.param
 

From d750db7dc7effa518dbfd538fe0442ef2588daae Mon Sep 17 00:00:00 2001
From: Thomas Kluyver
Date: Wed, 17 Aug 2022 18:04:16 +0100
Subject: [PATCH 08/16] Allow setting frames_per_chunk for dask_array()

---
 extra_data/components.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/extra_data/components.py b/extra_data/components.py
index 930987f3..cb997b98 100644
--- a/extra_data/components.py
+++ b/extra_data/components.py
@@ -1018,15 +1018,17 @@ def xarray(self, *, pulses=None, fill_value=None, roi=(), astype=None,
         return out
 
     def dask_array(self, *, labelled=True, subtrain_index='pulseId',
-                   fill_value=None, astype=None):
+                   fill_value=None, astype=None, frames_per_chunk=None):
         from dask.delayed import delayed
         from dask.array import concatenate, from_delayed
 
         entry_size = (self.dtype.itemsize *
             len(self.modno_to_keydata) * np.product(self._eg_keydata.entry_shape)
         )
-        # Aim for 1GB chunks, with an arbitrary maximum of 1024 frames
-        split = self.split_trains(frames_per_part=min(1024 ** 3 / entry_size, 1024))
+        if frames_per_chunk is None:
+            # Aim for 2GB chunks, with an arbitrary maximum of 1024 frames
+            frames_per_chunk = min(2 * 1024 ** 3 / entry_size, 1024)
+        split = self.split_trains(frames_per_part=frames_per_chunk)
 
         arr = concatenate([from_delayed(
             delayed(c.ndarray)(fill_value=fill_value, astype=astype),

From a138efa3a9a0473cd186e2d9e8f63f77f008f799 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver
Date: Thu, 18 Aug 2022 11:33:55 +0100
Subject: [PATCH 09/16] Fix renaming index levels for labelled Dask array

---
 extra_data/components.py            | 6 +++---
 extra_data/tests/test_components.py | 1 +
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/extra_data/components.py b/extra_data/components.py
index cb997b98..27ffafef 100644
--- a/extra_data/components.py
+++ b/extra_data/components.py
@@ -656,9 +656,9 @@ def get_dask_array(self, key, subtrain_index='pulseId', fill_value=None,
             # Preserve the quirks of this method before refactoring
             if self[key]._extraneous_dim:
                 arr = arr.expand_dims('tmp_name', axis=2)
-            arr.coords['train_pulse'] = arr.indexes['train_pulse'].rename(
-                ['trainId', subtrain_index]
-            )
+            arr.coords['train_pulse'] = arr.indexes['train_pulse'].rename({
+                'train': 'trainId', subtrain_index[:-2]: subtrain_index
+            })
             return arr.rename({name: f'dim_{i}' for i, name in enumerate(arr.dims[2:])})
diff --git a/extra_data/tests/test_components.py b/extra_data/tests/test_components.py
index b737a11c..26711d59 100644
--- a/extra_data/tests/test_components.py
+++ b/extra_data/tests/test_components.py
@@ -249,6 +249,7 @@ def test_get_dask_array_lpd_parallelgain(mock_lpd_parallelgain_run):
     assert det.detector_name == 'FXE_DET_LPD1M-1'
 
     arr = det.get_dask_array('image.data')
+    print(arr)
     assert arr.shape == (16, 2 * 3 * 100, 1, 256, 256)
     assert arr.dims[:2] == ('module', 'train_pulse')
     np.testing.assert_array_equal(arr.coords['pulseId'], np.tile(np.arange(100), 6))

From 3c125f68e7465f7fd454b0ca0c0581e5edddb939 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver
Date: Thu, 18 Aug 2022 11:35:01 +0100
Subject: [PATCH 10/16] Default to unlabelled array for .dask_array() for consistency

---
 extra_data/components.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extra_data/components.py
b/extra_data/components.py index 27ffafef..c1bd8dbd 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -1017,7 +1017,7 @@ def xarray(self, *, pulses=None, fill_value=None, roi=(), astype=None, return out - def dask_array(self, *, labelled=True, subtrain_index='pulseId', + def dask_array(self, *, labelled=False, subtrain_index='pulseId', fill_value=None, astype=None, frames_per_chunk=None): from dask.delayed import delayed from dask.array import concatenate, from_delayed From c510a2d1dd914948aaf105661ba231fa72f2891b Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Thu, 18 Aug 2022 11:36:57 +0100 Subject: [PATCH 11/16] Only import xarray where it's used --- extra_data/components.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/extra_data/components.py b/extra_data/components.py index c1bd8dbd..48f60ae8 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd -import xarray from .exceptions import SourceNameError from .reader import DataCollection, by_id, by_index @@ -1180,6 +1179,7 @@ def _get_slow_data(self, source, key, tid): for the train id tid - train id dimension is kept indexing frames within tid. """ + from xarray import DataArray file, pos, ds = self._find_data(source, key, tid) if file is None: return None @@ -1188,9 +1188,9 @@ def _get_slow_data(self, source, key, tid): firsts, counts = file.get_index(source, group) first, count = firsts[pos], counts[pos] if count == 1: - return xarray.DataArray(ds[first]) + return DataArray(ds[first]) else: - return xarray.DataArray(ds[first : first + count]) + return DataArray(ds[first : first + count]) def _get_pulse_data(self, source, key, tid): """ @@ -1212,6 +1212,7 @@ def _get_pulse_data(self, source, key, tid): xarray.DataArray Array of selected per pulse data. """ + from xarray import DataArray file, pos, ds = self._find_data(source, key, tid) if file is None: return None @@ -1253,7 +1254,7 @@ def _get_pulse_data(self, source, key, tid): dims = self.data[key].dimensions[1:] # excluding 'module' dim coords = {'train_pulse': train_pulse_ids} - arr = xarray.DataArray(data, coords=coords, dims=dims) + arr = DataArray(data, coords=coords, dims=dims) # Separate train & pulse dimensions, and arrange dimensions # so that the data is contiguous in memory. @@ -1312,6 +1313,7 @@ def _assemble_data(self, tid): xarray Assembled data array. """ + import xarray key_module_arrays = {} for modno, source in sorted(self.data.modno_to_source.items()): From fd4ad9afcf52601a6653d9a64db02396dde71996 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Thu, 18 Aug 2022 16:22:43 +0100 Subject: [PATCH 12/16] Another try at fixing up xarray wrapper --- extra_data/components.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/extra_data/components.py b/extra_data/components.py index 48f60ae8..37174650 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -645,6 +645,7 @@ def get_dask_array(self, key, subtrain_index='pulseId', fill_value=None, data type of the output array. 
If None (default) the dtype matches the input array dtype """ + from xarray import DataArray if subtrain_index not in {'pulseId', 'cellId'}: raise ValueError("subtrain_index must be 'pulseId' or 'cellId'") if key.startswith('image.'): @@ -655,10 +656,13 @@ def get_dask_array(self, key, subtrain_index='pulseId', fill_value=None, # Preserve the quirks of this method before refactoring if self[key]._extraneous_dim: arr = arr.expand_dims('tmp_name', axis=2) - arr.coords['train_pulse'] = arr.indexes['train_pulse'].rename({ - 'train': 'trainId', subtrain_index[:-2]: subtrain_index + frame_idx = arr.indexes['train_pulse'].set_names( + ['trainId', subtrain_index], level=[0, -1] + ) + dims = ['module', 'train_pulse'] + [f'dim_{i}' for i in range(arr.ndim - 2)] + return DataArray(arr.data, dims=dims, coords={ + 'train_pulse': frame_idx, 'module': arr.indexes['module'], }) - return arr.rename({name: f'dim_{i}' for i, name in enumerate(arr.dims[2:])}) else: return super().get_dask_array(key, fill_value=fill_value, astype=astype) From 8c1b9228aa57904b670692d7465a620178814738 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Fri, 28 Oct 2022 16:07:24 +0100 Subject: [PATCH 13/16] Remove debugging print() --- extra_data/tests/test_components.py | 1 - 1 file changed, 1 deletion(-) diff --git a/extra_data/tests/test_components.py b/extra_data/tests/test_components.py index 26711d59..b737a11c 100644 --- a/extra_data/tests/test_components.py +++ b/extra_data/tests/test_components.py @@ -249,7 +249,6 @@ def test_get_dask_array_lpd_parallelgain(mock_lpd_parallelgain_run): assert det.detector_name == 'FXE_DET_LPD1M-1' arr = det.get_dask_array('image.data') - print(arr) assert arr.shape == (16, 2 * 3 * 100, 1, 256, 256) assert arr.dims[:2] == ('module', 'train_pulse') np.testing.assert_array_equal(arr.coords['pulseId'], np.tile(np.arange(100), 6)) From 336e9f6d8f0d9be46ba8852e8860dd2d98eb609a Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Mon, 30 Jan 2023 16:22:51 +0000 Subject: [PATCH 14/16] More careful checking of pulse selection type --- extra_data/components.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/extra_data/components.py b/extra_data/components.py index 37174650..55a248bb 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -920,9 +920,11 @@ def _sel_frames(self): s = np.ones(len(self.det.train_ids_perframe), np.bool_) else: s = self.det._select_pulse_indices(p, self.det.frame_counts) - else: # by_id + elif isinstance(p, by_id): pulse_ids = self.det._collect_inner_ids('pulseId') s = _select_pulse_ids(p, pulse_ids) + else: + raise TypeError(f"Pulse selection should not be {type(p)}") self._sel_frames_cached = s return self._sel_frames_cached @@ -1234,8 +1236,10 @@ def _get_pulse_data(self, source, key, tid): if isinstance(self.pulses, by_id): positions = self._select_pulse_ids(pulse_ids) - else: # by_index + elif isinstance(self.pulses, by_index): positions = self._select_pulse_indices(count) + else: + raise TypeError(f"Pulse selection should not be {type(self.pulses)}") pulse_ids = pulse_ids[positions] train_ids = np.array([tid] * len(pulse_ids), dtype=np.uint64) train_pulse_ids = self.data._make_image_index(train_ids, pulse_ids) From 9bd8f080443d023ee9419f770a4420a3203dea42 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Mon, 30 Jan 2023 16:31:06 +0000 Subject: [PATCH 15/16] Rename _shape -> buffer_shape --- extra_data/components.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/extra_data/components.py 
b/extra_data/components.py index 55a248bb..c7eba55e 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -863,7 +863,13 @@ def _all_pulses(self): psv = self._pulse_sel.value return isinstance(psv, slice) and psv == slice(0, MAX_PULSES, 1) - def _shape(self, module_gaps=False, roi=()): + def buffer_shape(self, module_gaps=False, roi=()): + """Get the array shape for this data + + If *module_gaps* is True, include space for modules which are missing + from the data. *roi* may be a tuple of slices defining a region of + interest on the inner dimensions of the data. + """ module_dim = self.det.n_modules if module_gaps else len(self.modno_to_keydata) nframes_sel = len(self.train_id_coordinates()) @@ -875,7 +881,7 @@ def _shape(self, module_gaps=False, roi=()): @property def shape(self): - return self._shape() + return self.buffer_shape() def train_id_coordinates(self): # XTDF 'image' group can have >1 entry per train @@ -977,7 +983,7 @@ def _read_chunk(self, chunk: DataChunk, mod_out, roi): def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None, module_gaps=False): """Get an array of per-pulse data (image.*) for xtdf detector""" - out_shape = self._shape(module_gaps=module_gaps, roi=roi) + out_shape = self.buffer_shape(module_gaps=module_gaps, roi=roi) if out is None: dtype = self._eg_keydata.dtype if astype is None else np.dtype(astype) From 639dc2b36b14611db3a4b31f3acc53d0704094c8 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver Date: Tue, 31 Jan 2023 13:17:05 +0000 Subject: [PATCH 16/16] Better way to mark type of attribute --- extra_data/components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extra_data/components.py b/extra_data/components.py index c7eba55e..9482592b 100644 --- a/extra_data/components.py +++ b/extra_data/components.py @@ -847,10 +847,10 @@ def dask_array(self, *, labelled=False, fill_value=None, astype=None): class XtdfImageMultimodKeyData(MultimodKeyData): _sel_frames_cached = None + det: XtdfDetectorBase def __init__(self, det: XtdfDetectorBase, key, pulse_sel=by_index[0:MAX_PULSES:1]): super().__init__(det, key) - self.det = det # Makes PyCharm happy that det is XtdfDetectorBase self._pulse_sel = pulse_sel entry_shape = self._eg_keydata.entry_shape self._extraneous_dim = (len(entry_shape) >= 1) and (entry_shape[0] == 1)
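
Taken together, the series gives multi-module detector components a per-key
interface mirroring the single-source KeyData API. A hedged end-to-end sketch of
the intended usage; the proposal and run numbers and variable names are
illustrative, and only methods added in the patches above are assumed:

    from extra_data import open_run, by_index
    from extra_data.components import LPD1M

    run = open_run(proposal=700000, run=1)   # hypothetical run
    det = LPD1M(run, min_modules=8)

    images = det['image.data']               # -> XtdfImageMultimodKeyData
    print(images.shape, images.dtype)        # (modules, frames, slow_scan, fast_scan)

    # Pulse selection composes with the array accessors
    arr = images.select_pulses(by_index[:16]).xarray(fill_value=0, unstack_pulses=True)

    # Lazy access; frames_per_chunk added in patch 08, unlabelled default from patch 10
    darr = images.dask_array(labelled=True, frames_per_chunk=512)

    # Eager, piecewise processing via the split_trains() added in patch 02
    for piece in images.split_trains(trains_per_part=32):
        block = piece.ndarray()              # plain NumPy, (modules, frames, ...)

Reading everything through det[key] keeps get_array() and get_dask_array() as
thin compatibility wrappers, which is the behaviour patches 09 and 12 work to
preserve.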