From d5b588dd9a3520081198ffb9f918505319dac651 Mon Sep 17 00:00:00 2001 From: Brendan Moloney Date: Sat, 1 Aug 2020 15:54:10 -0700 Subject: [PATCH 1/7] WIP: Initial meta summary work --- nibabel/metasum.py | 464 ++++++++++++++++++++++++++++++++++ nibabel/tests/test_metasum.py | 0 2 files changed, 464 insertions(+) create mode 100644 nibabel/metasum.py create mode 100644 nibabel/tests/test_metasum.py diff --git a/nibabel/metasum.py b/nibabel/metasum.py new file mode 100644 index 0000000000..9dc5dfe5af --- /dev/null +++ b/nibabel/metasum.py @@ -0,0 +1,464 @@ +# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +# +# See COPYING file distributed along with the NiBabel package for the +# copyright and license terms. +# +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +'''Aggregate information for mutliple images +''' +from bitarray import bitarray, frozenbitarray +from bitarry.utils import zeroes + + +class FloatCanon(object): + '''Look up a canonical float that we compare equal to''' + def __init__(self, n_digits=6): + self._n_digits = n_digits + self._offset = 0.5 * (10 ** -n_digits) + self._canon_vals = set() + self._rounded = {} + + def get(self, val): + '''Get a canonical value that at least compares equal to `val`''' + res = self._values.get(val) + if res is not None: + return res + lb = round(val, self._n_digits) + res = self._rounded.get(lb) + if res is not None: + return res + ub = round(val + self._offset, self._n_digits) + res = self._rounded.get(ub) + if res is not None: + return res + + +_NoValue = object() + +# TODO: Integrate some value canonicalization filtering? Or just require the +# user to do that themselves? +class ValueIndices(object): + """Track indices of values in sequence. + + If values repeat frequently then memory usage can be dramatically improved. + It can be thought of as the inverse to a list. + + >>> values = ['a', 'a', 'b', 'a', 'b'] + >>> vidx = ValueIndices(values) + >>> vidx['a'] + [0, 1, 3] + >>> vidx['b'] + [2, 4] + """ + + def __init__(self, values=None): + """Initialize a ValueIndices instance. + + Parameters + ---------- + values : sequence + The sequence of values to track indices on + """ + + self._n_input = 0 + + # The values can be constant, unique to specific indices, or + # arbitrarily varying + self._const_val = _NoValue + self._unique_vals = {} + self._val_bitarrs = {} + + if values is not None: + self.extend(values) + + @property + def n_input(self): + '''The number of inputs we are indexing''' + return self._n_input + + def __len__(self): + '''Number of unique values being tracked''' + if self._const_val is not _NoValue: + return 1 + return len(self._unique_vals) + len(self._val_bitarrs) + + def __getitem__(self, value): + '''Return list of indices for the given value''' + if self._const_val == value: + return list(range(self._n_input)) + idx = self._unique_vals.get(value) + if idx is not None: + return [idx] + ba = self._val_bitarrs[value] + return list(self._extract_indices(ba)) + + def values(self): + '''Generate each unique value that has been seen''' + if self._const_val is not _NoValue: + yield self._const_val + return + for val in self._unique_vals.keys(): + yield val + for val in self._val_bitarrs.keys(): + yield val + + def get_mask(self, value): + '''Get bitarray mask of indices with this value''' + if self._const_val is not _NoValue: + if self._const_val != value: + raise KeyError() + res = bitarray(self._n_input) + res.setall(1) + return res + idx = self._unique_vals.get(value) + if idx is not None: + res = zeroes(self._n_inpuf) + res[idx] = 1 + return res + return self._val_bitarrs[value].copy() + + def num_indices(self, value): + '''Number of indices for the given `value`''' + if self._const_val is not _NoValue: + if self._const_val != value: + raise KeyError() + return self._n_input + if value in self._unique_vals: + return 1 + return self._val_bitarrs[value].count() + + def get_value(self, idx): + '''Get the value at `idx`''' + if not 0 <= idx < self._n_input: + raise IndexError() + if self._const_val is not _NoValue: + return self._const_val + for val, vidx in self._unique_vals.items(): + if vidx == idx: + return val + bit_idx = zeroes(self._n_input) + bit_idx[idx] = 1 + for val, ba in self._val_bitarrs.items(): + if (ba | bit_idx).any(): + return val + assert False + + def extend(self, values): + '''Add more values to the end of any existing ones''' + curr_size = self._n_input + if isinstance(values, ValueIndices): + other_is_vi = True + other_size = values._n_input + else: + other_is_vi = False + other_size = len(values) + final_size = curr_size + other_size + for ba in self._val_bitarrs.values(): + ba.extend(zeroes(other_size)) + if other_is_vi: + if self._const_val is not _NoValue: + if values._const_val is not _NoValue: + self._extend_const(values) + return + else: + self._rm_const() + elif values._const_val is not _NoValue: + cval = values._const_val + other_unique = {} + other_bitarrs = {} + if values._n_input == 1: + other_unique[cval] = 0 + else: + other_bitarrs[cval] = bitarray(values._n_input) + other_bitarrs[cval].setall(1) + else: + other_unique = values._unique_vals + other_bitarrs = values._val_bitarrs + for val, other_idx in other_unique.items(): + self._ingest_single(val, final_size, curr_size, other_idx) + for val, other_ba in other_bitarrs.items(): + curr_ba = self._val_bitarrs.get(val) + if curr_ba is None: + curr_idx = self._unique_vals.get(val) + if curr_idx is None: + if curr_size == 0: + new_ba = other_ba.copy() + else: + new_ba = zeroes(curr_size) + new_ba.extend(other_ba) + else: + new_ba = zeroes(curr_size) + new_ba[curr_idx] = True + new_ba.extend(other_ba) + del self._unique_vals[val] + self._val_bitarrs[val] = new_ba + else: + curr_ba[curr_size:] |= other_ba + else: + for other_idx, val in enumerate(values): + self._ingest_single(val, final_size, curr_size, other_idx) + self._n_input = final_size + + def append(self, value): + '''Append another value as input''' + if self._const_val == value: + self._n_input += 1 + return + elif self._const_val is not _NoValue: + self._rm_const() + curr_size = self._n_input + found = False + for val, bitarr in self._val_bitarrs.items(): + if val == value: + found = True + bitarr.append(True) + else: + bitarr.append(False) + if not found: + curr_idx = self._unique_vals.get(value) + if curr_idx is None: + self._unique_vals[value] = curr_size + else: + new_ba = zeroes(curr_size + 1) + new_ba[curr_idx] = True + new_ba[curr_size] = True + self._val_bitarrs[value] = new_ba + del self._unique_vals[value] + self._n_input += 1 + + def argsort(self, reverse=False): + '''Return array of indices in order that sorts the values''' + if self._const_val is not _NoValue: + return np.arange(self._n_input) + res = np.empty(self._n_input, dtype=np.int64) + vals = list(self._unique_vals.keys()) + list(self._val_bitarrs.keys()) + vals.sort(reverse=reverse) + res_idx = 0 + for val in vals: + idx = self._unique_vals.get(val) + if idx is not None: + res[res_idx] = idx + res_idx += 1 + continue + ba = self._val_bitarrs[val] + for idx in self._extract_indices(ba): + res[res_idx] = idx + res_idx += 1 + return res + + def is_covariant(self, other): + '''True if `other` has values that vary the same way ours do + + The actual values themselves are ignored + ''' + if self._n_input != other._n_input or len(self) != len(other): + return False + if self._const_val is not _NoValue: + return other._const_val is not _NoValue + if self._n_input == len(self): + return other._n_input == len(other) + self_ba_set = set(frozenbitarray(ba) for ba in self._val_bitarrs.values()) + other_ba_set = set(frozenbitarray(ba) for ba in other._val_bitarrs.values()) + if self_ba_set != other_ba_set: + return False + if len(self._unique_vals) != len(other._unique_vals): + return False + return True + + def is_blocked(self, block_factor=None): + '''True if each value has the same number of indices + + If `block_factor` is not None we also test that it evenly divides the + block size. + ''' + block_size, rem = divmod(self._n_input, len(self)) + if rem != 0: + return False + if block_factor is not None and block_size % block_factor != 0: + return False + for val in self.values(): + if self.num_indices(val) != block_size: + return False + return True + + def is_subpartition(self, other): + '''True if we have more values and they nest within values from other + + + ''' + + def _extract_indices(self, ba): + '''Generate integer indices from bitarray representation''' + start = 0 + while True: + try: + # TODO: Is this the most efficient approach? + curr_idx = ba.index(True, start=start) + except ValueError: + return + yield curr_idx + start = curr_idx + + def _ingest_single(self, val, final_size, curr_size, other_idx): + '''Helper to ingest single value from another collection''' + curr_ba = self._val_bitarrs.get(val) + if curr_ba is None: + curr_idx = self._unique_vals.get(val) + if curr_idx is None: + self._unique_vals[val] = curr_size + other_idx + else: + new_ba = zeroes(final_size) + new_ba[curr_idx] = True + new_ba[curr_size + other_idx] = True + self._val_bitarrs = new_ba + del self._unique_vals[val] + else: + curr_ba[curr_size + other_idx] = True + + def _rm_const(self): + assert self._const_val is not _NoValue + if self._n_input == 1: + self._unique_vals[self._const_val] = 0 + else: + self._val_bitarrs[self._const_val] = bitarray(self._n_input) + self._val_bitarrs[self._const_val].setall(1) + self._const_val == _NoValue + + def _extend_const(self, other): + if self._const_val != other._const_val: + if self._n_input == 1: + self._unique_vals[self._const_val] = 0 + else: + self_ba = bitarray(self._n_input) + other_ba = bitarray(other._n_input) + self_ba.setall(1) + other_ba.setall(0) + self._val_bitarrs[self._const_val] = self_ba + other_ba + if other._n_input == 1: + self._unique_vals[other._const_val] = self._n_input + else: + self_ba = bitarray(self._n_input) + other_ba = bitarray(other._n_input) + self_ba.setall(0) + other_ba.setall(1) + self._val_bitarrs[other._const_val] = self_ba + other_ba + self._const_val = _NoValue + self._n_input += other._n_input + + +_MissingKey = object() + + +class MetaSummary: + '''Summarize a sequence of dicts, tracking how individual keys vary + + The assumption is that for any key many values will be constant, or at + least repeated, and thus we can reduce memory consumption by only storing + the value once along with the indices it appears at. + ''' + def __init__(self): + self._v_idxs = {} + self._n_input = 0 + + @property + def n_input(self): + return self._n_input + + def append(self, meta): + seen = set() + for key, v_idx in self._v_idxs.items(): + val = meta.get(key, _MissingKey) + v_idx.append(val) + seen.add(key) + for key, val in meta.items(): + if key in seen: + continue + v_idx = ValueIndices([_MissingKey for _ in range(self._n_input)]) + v_idx.append(val) + self._v_idxs[key] = v_idx + self._n_input += 1 + + def extend(self, metas): + pass # TODO + + def keys(self): + '''Generate all known keys''' + return self._v_idxs.keys() + + def const_keys(self): + '''Generate keys with a constant value across all inputs''' + for key, v_idx in self._v_idxs.items(): + if len(v_idx) == 1: + yield key + + def unique_keys(self): + '''Generate keys with a unique value in each input''' + n_input = self._n_input + if n_input <= 1: + return + for key, v_idx in self._v_idxs.items(): + if len(v_idx) == n_input: + yield key + + def repeating_keys(self): + '''Generate keys that have some repeating component but are not const + ''' + n_input = self._n_input + if n_input <= 1: + return + for key, v_idx in self._v_idxs.items(): + if 1 < len(v_idx) < n_input: + yield key + + def repeating_groups(self, block_only=False, block_factor=None): + '''Generate groups of repeating keys that vary with the same pattern + ''' + n_input = self._n_input + if n_input <= 1: + # If there is only one element, consider all keys as const + return + # TODO: Can we sort so grouped v_idxs are sequential? + # - Sort by num values isn't sufficient + curr_group = [] + for key, v_idx in self._v_idxs.items(): + if 1 < len(v_idx) < n_input: + if v_idx.is_even(block_factor): + pass # TODO + + def get_meta(self, idx): + '''Get the full dict at the given index''' + res = {} + for key, v_idx in self._v_idxs.items(): + val = v_idx.get_value(idx) + if val is _MissingKey: + continue + res[key] = val + return res + + def get_val(self, idx, key, default=None): + '''Get the value at `idx` for the `key`, or return `default``''' + res = self._v_idxs[key].get_value(key) + if res is _MissingKey: + return default + return res + + def nd_sort(self, dim_keys=None): + '''Produce indices ordered so as to fill an n-D array''' + +class SummaryTree: + '''Groups incoming meta data and creates hierarchy of related groups + + Each leaf node in the tree is a `MetaSummary` + ''' + def __init__(self, group_keys): + self._group_keys = group_keys + self._group_summaries= {} + + def add(self, meta): + pass + + def groups(self): + '''Generate the groups and their meta summaries''' + diff --git a/nibabel/tests/test_metasum.py b/nibabel/tests/test_metasum.py new file mode 100644 index 0000000000..e69de29bb2 From cb3222bb1b489a17296357a115c27d44fbbd5d61 Mon Sep 17 00:00:00 2001 From: Brendan Moloney Date: Fri, 9 Jul 2021 10:00:06 -0700 Subject: [PATCH 2/7] WIP: Basics mostly working, needs more testing and finish ndSort --- nibabel/metasum.py | 245 +++++++++++++++++++++++++--------- nibabel/tests/test_metasum.py | 63 +++++++++ 2 files changed, 245 insertions(+), 63 deletions(-) diff --git a/nibabel/metasum.py b/nibabel/metasum.py index 9dc5dfe5af..7daeb8ac04 100644 --- a/nibabel/metasum.py +++ b/nibabel/metasum.py @@ -6,14 +6,18 @@ # copyright and license terms. # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## -'''Aggregate information for mutliple images +'''Memory efficient tracking of meta data dicts with repeating elements ''' +from dataclasses import dataclass +from enum import IntEnum + from bitarray import bitarray, frozenbitarray -from bitarry.utils import zeroes +from bitarray.util import zeros -class FloatCanon(object): +class FloatCanon: '''Look up a canonical float that we compare equal to''' + def __init__(self, n_digits=6): self._n_digits = n_digits self._offset = 0.5 * (10 ** -n_digits) @@ -39,7 +43,9 @@ def get(self, val): # TODO: Integrate some value canonicalization filtering? Or just require the # user to do that themselves? -class ValueIndices(object): + + +class ValueIndices: """Track indices of values in sequence. If values repeat frequently then memory usage can be dramatically improved. @@ -114,19 +120,31 @@ def get_mask(self, value): return res idx = self._unique_vals.get(value) if idx is not None: - res = zeroes(self._n_inpuf) + res = zeros(self._n_inpuf) res[idx] = 1 return res return self._val_bitarrs[value].copy() - def num_indices(self, value): + def num_indices(self, value, mask=None): '''Number of indices for the given `value`''' + if mask is not None: + if len(mask) != self.n_input: + raise ValueError("Mask length must match input length") if self._const_val is not _NoValue: if self._const_val != value: raise KeyError() - return self._n_input - if value in self._unique_vals: + if mask is None: + return self._n_input + return mask.count() + unique_idx = self._unique_vals.get(_NoValue) + if unique_idx is not _NoValue: + if mask is not None: + if mask[unique_idx]: + return 1 + return 0 return 1 + if mask is not None: + return (self._val_bitarrs[value] & mask).count return self._val_bitarrs[value].count() def get_value(self, idx): @@ -138,13 +156,17 @@ def get_value(self, idx): for val, vidx in self._unique_vals.items(): if vidx == idx: return val - bit_idx = zeroes(self._n_input) + bit_idx = zeros(self._n_input) bit_idx[idx] = 1 for val, ba in self._val_bitarrs.items(): - if (ba | bit_idx).any(): + if (ba & bit_idx).any(): return val assert False + def to_list(self): + '''Convert back to a list of values''' + return [self.get_value(i) for i in range(self.n_input)] + def extend(self, values): '''Add more values to the end of any existing ones''' curr_size = self._n_input @@ -156,7 +178,7 @@ def extend(self, values): other_size = len(values) final_size = curr_size + other_size for ba in self._val_bitarrs.values(): - ba.extend(zeroes(other_size)) + ba.extend(zeros(other_size)) if other_is_vi: if self._const_val is not _NoValue: if values._const_val is not _NoValue: @@ -186,10 +208,10 @@ def extend(self, values): if curr_size == 0: new_ba = other_ba.copy() else: - new_ba = zeroes(curr_size) + new_ba = zeros(curr_size) new_ba.extend(other_ba) else: - new_ba = zeroes(curr_size) + new_ba = zeros(curr_size) new_ba[curr_idx] = True new_ba.extend(other_ba) del self._unique_vals[val] @@ -221,13 +243,20 @@ def append(self, value): if curr_idx is None: self._unique_vals[value] = curr_size else: - new_ba = zeroes(curr_size + 1) + new_ba = zeros(curr_size + 1) new_ba[curr_idx] = True new_ba[curr_size] = True self._val_bitarrs[value] = new_ba del self._unique_vals[value] self._n_input += 1 + def reverse(self): + '''Reverse the indices in place''' + for val, idx in self._unique_vals.items(): + self._unique_vals[val] = self._n_input - idx - 1 + for val, bitarr in self._val_bitarrs.items(): + bitarr.reverse() + def argsort(self, reverse=False): '''Return array of indices in order that sorts the values''' if self._const_val is not _NoValue: @@ -248,6 +277,18 @@ def argsort(self, reverse=False): res_idx += 1 return res + def reorder(self, order): + '''Reorder the indices in place''' + if len(order) != self._n_input: + raise ValueError("The 'order' has the incorrect length") + for val, idx in self._unique_vals.items(): + self._unique_vals[val] = order.index(idx) + for val, bitarr in self._val_bitarrs.items(): + new_ba = zeros(self._n_input) + for idx in self._extract_indices(bitarr): + new_ba[order.index(idx)] = True + self._val_bitarrs[val] = new_ba + def is_covariant(self, other): '''True if `other` has values that vary the same way ours do @@ -267,27 +308,22 @@ def is_covariant(self, other): return False return True - def is_blocked(self, block_factor=None): - '''True if each value has the same number of indices + def get_block_size(self): + '''Return size of even blocks of values, or None if values aren't "blocked" - If `block_factor` is not None we also test that it evenly divides the - block size. + The number of values must evenly divide the number of inputs into the block size, + with each value appearing that same number of times. ''' block_size, rem = divmod(self._n_input, len(self)) if rem != 0: - return False - if block_factor is not None and block_size % block_factor != 0: - return False + return None for val in self.values(): if self.num_indices(val) != block_size: - return False - return True + return None + return block_size def is_subpartition(self, other): - '''True if we have more values and they nest within values from other - - - ''' + '''''' def _extract_indices(self, ba): '''Generate integer indices from bitarray representation''' @@ -295,7 +331,7 @@ def _extract_indices(self, ba): while True: try: # TODO: Is this the most efficient approach? - curr_idx = ba.index(True, start=start) + curr_idx = ba.index(True, start) except ValueError: return yield curr_idx @@ -309,10 +345,10 @@ def _ingest_single(self, val, final_size, curr_size, other_idx): if curr_idx is None: self._unique_vals[val] = curr_size + other_idx else: - new_ba = zeroes(final_size) + new_ba = zeros(final_size) new_ba[curr_idx] = True new_ba[curr_size + other_idx] = True - self._val_bitarrs = new_ba + self._val_bitarrs[val] = new_ba del self._unique_vals[val] else: curr_ba[curr_size + other_idx] = True @@ -351,6 +387,25 @@ def _extend_const(self, other): _MissingKey = object() +class DimTypes(IntEnum): + '''Enmerate the three types of nD dimensions''' + SLICE = 1 + TIME = 2 + PARAM = 3 + + +@dataclass +class DimIndex: + '''Specify an nD index''' + dim_type: DimTypes + + key: str + + +class NdSortError(Exception): + '''Raised when the data cannot be sorted into an nD array as specified''' + + class MetaSummary: '''Summarize a sequence of dicts, tracking how individual keys vary @@ -358,6 +413,7 @@ class MetaSummary: least repeated, and thus we can reduce memory consumption by only storing the value once along with the indices it appears at. ''' + def __init__(self): self._v_idxs = {} self._n_input = 0 @@ -380,9 +436,6 @@ def append(self, meta): self._v_idxs[key] = v_idx self._n_input += 1 - def extend(self, metas): - pass # TODO - def keys(self): '''Generate all known keys''' return self._v_idxs.keys() @@ -412,20 +465,26 @@ def repeating_keys(self): if 1 < len(v_idx) < n_input: yield key - def repeating_groups(self, block_only=False, block_factor=None): - '''Generate groups of repeating keys that vary with the same pattern + def covariant_groups(self, keys=None, block_only=False): + '''Generate groups of keys that vary with the same pattern ''' - n_input = self._n_input - if n_input <= 1: - # If there is only one element, consider all keys as const - return - # TODO: Can we sort so grouped v_idxs are sequential? - # - Sort by num values isn't sufficient - curr_group = [] - for key, v_idx in self._v_idxs.items(): - if 1 < len(v_idx) < n_input: - if v_idx.is_even(block_factor): - pass # TODO + if keys is None: + keys = self.keys() + groups = [] + for key in keys: + v_idx = self._v_idxs[key] + if len(groups) == 0: + groups.append((key, v_idx)) + continue + for group in groups: + if group[0][1].is_covariant(v_idx): + group.append(key) + break + else: + groups.append((key, v_idx)) + for group in groups: + group[0] = group[0][0] + return groups def get_meta(self, idx): '''Get the full dict at the given index''' @@ -439,26 +498,86 @@ def get_meta(self, idx): def get_val(self, idx, key, default=None): '''Get the value at `idx` for the `key`, or return `default``''' - res = self._v_idxs[key].get_value(key) + res = self._v_idxs[key].get_value(idx) if res is _MissingKey: return default return res - def nd_sort(self, dim_keys=None): - '''Produce indices ordered so as to fill an n-D array''' + def reorder(self, order): + '''Reorder indices in place''' + for v_idx in self._v_idxs.values(): + v_idx.reorder(order) -class SummaryTree: - '''Groups incoming meta data and creates hierarchy of related groups - - Each leaf node in the tree is a `MetaSummary` - ''' - def __init__(self, group_keys): - self._group_keys = group_keys - self._group_summaries= {} - - def add(self, meta): - pass - - def groups(self): - '''Generate the groups and their meta summaries''' + def nd_sort(self, dims): + '''Produce linear indices to fill nD array as specified by `dims` + Assumes each input corresponds to a 2D or 3D array, and the combined + array is 3D+ + ''' + # Make sure dims aren't completely invalid + if len(dims) == 0: + raise ValueError("At least one dimension must be specified") + last_dim = None + for dim in dims: + if last_dim is not None: + if last_dim.dim_type > dim.dim_type: + # TODO: This only allows PARAM dimensions at the end, which I guess is reasonable? + raise ValueError("Invalid dimension order") + elif last_dim.dim_type == dim.dim_type and dim.dim_type != DimTypes.PARAM: + raise ValueError("There can be at most one each of SLICE and TIME dimensions") + last_dim = dim + + # Pull out info about different types of dims + n_slices = None + n_vol = None + time_dim = None + param_dims = [] + n_params = [] + total_params = 1 + shape = [] + curr_size = 1 + for dim in dims: + dim_vidx = self._v_idxs[dim.key] + dim_type = dim.dim_type + if dim_type is DimTypes.SLICE: + n_slices = len(dim_vidx) + n_vol = dim_vidx.get_block_size() + if n_vol is None: + raise NdSortError("There are missing or extra slices") + shape.append(n_slices) + curr_size *= n_slices + elif dim_type is DimTypes.TIME: + time_dim = dim + elif dim_type is DimTypes.PARAM: + if dim_vidx.get_block_size() is None: + raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs") + param_dims.append(dim) + n_param = len(dim_vidx) + n_params.append(n_param) + total_params *= n_param + if n_vol is None: + n_vol = self._n_input + + # Size of the time dimension must be infered from the size of the other dims + n_time = 1 + if time_dim is not None: + n_time, rem = divmod(n_vol, total_params) + if rem != 0: + raise NdSortError(f"The combined parameters don't evenly divide inputs") + shape.append(n_time) + curr_size *= n_time + + # Complete the "shape", and do a more detailed check that our param dims make sense + for dim, n_param in zip(param_dims, n_params): + dim_vidx = self._v_idxs[dim.key] + if dim_vidx.get_block_size() != curr_size: + raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs") + shape.append(n_param) + curr_size *= n_param + + # Extract dim keys for each input and do the actual sort + sort_keys = [(idx, tuple(self.get_val(idx, dim.key) for dim in reversed(dims))) + for idx in range(self._n_input)] + sort_keys.sort(key=lambda x: x[1]) + + # TODO: Finish this diff --git a/nibabel/tests/test_metasum.py b/nibabel/tests/test_metasum.py index e69de29bb2..c654e82614 100644 --- a/nibabel/tests/test_metasum.py +++ b/nibabel/tests/test_metasum.py @@ -0,0 +1,63 @@ +from ..metasum import MetaSummary, ValueIndices + +import pytest + + +vidx_test_patterns = ([0] * 8, + ([0] * 4) + ([1] * 4), + [0, 0, 1, 2, 3, 3, 3, 4], + list(range(8)), + list(range(6)) + [6] * 2, + ([0] * 2) + list(range(2, 8)), + ) + + +@pytest.mark.parametrize("in_list", vidx_test_patterns) +def test_value_indices_rt(in_list): + '''Test we can roundtrip list -> ValueIndices -> list''' + vidx = ValueIndices(in_list) + out_list = vidx.to_list() + assert in_list == out_list + + +@pytest.mark.parametrize("in_list", vidx_test_patterns) +def test_value_indices_append_extend(in_list): + '''Test that append/extend are equivalent''' + vidx_list = [ValueIndices() for _ in range(4)] + vidx_list[0].extend(in_list) + vidx_list[0].extend(in_list) + for val in in_list: + vidx_list[1].append(val) + for val in in_list: + vidx_list[1].append(val) + vidx_list[2].extend(in_list) + for val in in_list: + vidx_list[2].append(val) + for val in in_list: + vidx_list[3].append(val) + vidx_list[3].extend(in_list) + for vidx in vidx_list: + assert vidx.to_list() == in_list + in_list + + +metasum_test_dicts = (({'key1': 0, 'key2': 'a', 'key3': 3.0}, + {'key1': 2, 'key2': 'c', 'key3': 1.0}, + {'key1': 1, 'key2': 'b', 'key3': 2.0}, + ), + ({'key1': 0, 'key2': 'a', 'key3': 3.0}, + {'key1': 2, 'key2': 'c'}, + {'key1': 1, 'key2': 'b', 'key3': 2.0}, + ), + ) + + +@pytest.mark.parametrize("in_dicts", metasum_test_dicts) +def test_meta_summary_rt(in_dicts): + msum = MetaSummary() + for in_dict in in_dicts: + msum.append(in_dict) + for in_idx in range(len(in_dicts)): + out_dict = msum.get_meta(in_idx) + assert out_dict == in_dicts[in_idx] + for key, in_val in in_dicts[in_idx].items(): + assert in_val == msum.get_val(in_idx, key) From c21a8fd46014daec584d8133cf7b22acb0dcbec2 Mon Sep 17 00:00:00 2001 From: Brendan Moloney Date: Fri, 9 Jul 2021 16:32:58 -0700 Subject: [PATCH 3/7] TST+BF: Expand tests and fix bugs --- nibabel/metasum.py | 68 ++++++++++++++++++++++------------- nibabel/tests/test_metasum.py | 33 ++++++++++++----- 2 files changed, 69 insertions(+), 32 deletions(-) diff --git a/nibabel/metasum.py b/nibabel/metasum.py index 7daeb8ac04..d1e84dad55 100644 --- a/nibabel/metasum.py +++ b/nibabel/metasum.py @@ -125,7 +125,7 @@ def get_mask(self, value): return res return self._val_bitarrs[value].copy() - def num_indices(self, value, mask=None): + def count(self, value, mask=None): '''Number of indices for the given `value`''' if mask is not None: if len(mask) != self.n_input: @@ -136,7 +136,7 @@ def num_indices(self, value, mask=None): if mask is None: return self._n_input return mask.count() - unique_idx = self._unique_vals.get(_NoValue) + unique_idx = self._unique_vals.get(value, _NoValue) if unique_idx is not _NoValue: if mask is not None: if mask[unique_idx]: @@ -144,7 +144,7 @@ def num_indices(self, value, mask=None): return 0 return 1 if mask is not None: - return (self._val_bitarrs[value] & mask).count + return (self._val_bitarrs[value] & mask).count() return self._val_bitarrs[value].count() def get_value(self, idx): @@ -169,14 +169,14 @@ def to_list(self): def extend(self, values): '''Add more values to the end of any existing ones''' - curr_size = self._n_input + init_size = self._n_input if isinstance(values, ValueIndices): other_is_vi = True other_size = values._n_input else: other_is_vi = False other_size = len(values) - final_size = curr_size + other_size + final_size = init_size + other_size for ba in self._val_bitarrs.values(): ba.extend(zeros(other_size)) if other_is_vi: @@ -185,7 +185,7 @@ def extend(self, values): self._extend_const(values) return else: - self._rm_const() + self._rm_const(final_size) elif values._const_val is not _NoValue: cval = values._const_val other_unique = {} @@ -199,29 +199,30 @@ def extend(self, values): other_unique = values._unique_vals other_bitarrs = values._val_bitarrs for val, other_idx in other_unique.items(): - self._ingest_single(val, final_size, curr_size, other_idx) + self._ingest_single(val, final_size, init_size, other_idx) for val, other_ba in other_bitarrs.items(): curr_ba = self._val_bitarrs.get(val) if curr_ba is None: curr_idx = self._unique_vals.get(val) if curr_idx is None: - if curr_size == 0: + if init_size == 0: new_ba = other_ba.copy() else: - new_ba = zeros(curr_size) + new_ba = zeros(init_size) new_ba.extend(other_ba) else: - new_ba = zeros(curr_size) + new_ba = zeros(init_size) new_ba[curr_idx] = True new_ba.extend(other_ba) del self._unique_vals[val] self._val_bitarrs[val] = new_ba else: - curr_ba[curr_size:] |= other_ba + curr_ba[init_size:] |= other_ba + self._n_input += other_ba.count() else: for other_idx, val in enumerate(values): - self._ingest_single(val, final_size, curr_size, other_idx) - self._n_input = final_size + self._ingest_single(val, final_size, init_size, other_idx) + assert self._n_input == final_size def append(self, value): '''Append another value as input''' @@ -229,10 +230,18 @@ def append(self, value): self._n_input += 1 return elif self._const_val is not _NoValue: - self._rm_const() + self._rm_const(self._n_input + 1) + self._unique_vals[value] = self._n_input + self._n_input += 1 + return + if self._n_input == 0: + self._const_val = value + self._n_input += 1 + return curr_size = self._n_input found = False for val, bitarr in self._val_bitarrs.items(): + assert len(bitarr) == self._n_input if val == value: found = True bitarr.append(True) @@ -318,7 +327,7 @@ def get_block_size(self): if rem != 0: return None for val in self.values(): - if self.num_indices(val) != block_size: + if self.count(val) != block_size: return None return block_size @@ -335,32 +344,43 @@ def _extract_indices(self, ba): except ValueError: return yield curr_idx - start = curr_idx + start = curr_idx + 1 - def _ingest_single(self, val, final_size, curr_size, other_idx): + def _ingest_single(self, val, final_size, init_size, other_idx): '''Helper to ingest single value from another collection''' + if val == self._const_val: + self._n_input += 1 + return + elif self._const_val is not _NoValue: + self._rm_const(final_size) + if self._n_input == 0: + self._const_val = val + self._n_input += 1 + return + curr_ba = self._val_bitarrs.get(val) if curr_ba is None: curr_idx = self._unique_vals.get(val) if curr_idx is None: - self._unique_vals[val] = curr_size + other_idx + self._unique_vals[val] = init_size + other_idx else: new_ba = zeros(final_size) new_ba[curr_idx] = True - new_ba[curr_size + other_idx] = True + new_ba[init_size + other_idx] = True self._val_bitarrs[val] = new_ba del self._unique_vals[val] else: - curr_ba[curr_size + other_idx] = True + curr_ba[init_size + other_idx] = True + self._n_input += 1 - def _rm_const(self): + def _rm_const(self, final_size): assert self._const_val is not _NoValue if self._n_input == 1: self._unique_vals[self._const_val] = 0 else: - self._val_bitarrs[self._const_val] = bitarray(self._n_input) - self._val_bitarrs[self._const_val].setall(1) - self._const_val == _NoValue + self._val_bitarrs[self._const_val] = zeros(final_size) + self._val_bitarrs[self._const_val][:self._n_input] = True + self._const_val = _NoValue def _extend_const(self, other): if self._const_val != other._const_val: diff --git a/nibabel/tests/test_metasum.py b/nibabel/tests/test_metasum.py index c654e82614..c0aced4d2a 100644 --- a/nibabel/tests/test_metasum.py +++ b/nibabel/tests/test_metasum.py @@ -13,9 +13,16 @@ @pytest.mark.parametrize("in_list", vidx_test_patterns) -def test_value_indices_rt(in_list): +def test_value_indices_basics(in_list): '''Test we can roundtrip list -> ValueIndices -> list''' vidx = ValueIndices(in_list) + assert vidx.n_input == len(in_list) + assert len(vidx) == len(set(in_list)) + assert sorted(vidx.values()) == sorted(list(set(in_list))) + for val in vidx.values(): + assert vidx.count(val) == in_list.count(val) + for in_idx in vidx[val]: + assert in_list[in_idx] == val out_list = vidx.to_list() assert in_list == out_list @@ -40,22 +47,32 @@ def test_value_indices_append_extend(in_list): assert vidx.to_list() == in_list + in_list -metasum_test_dicts = (({'key1': 0, 'key2': 'a', 'key3': 3.0}, - {'key1': 2, 'key2': 'c', 'key3': 1.0}, - {'key1': 1, 'key2': 'b', 'key3': 2.0}, +metasum_test_dicts = (({'u1': 0, 'u2': 'a', 'u3': 3.0, 'c1': True, 'r1': 5}, + {'u1': 2, 'u2': 'c', 'u3': 1.0, 'c1': True, 'r1': 5}, + {'u1': 1, 'u2': 'b', 'u3': 2.0, 'c1': True, 'r1': 7}, ), - ({'key1': 0, 'key2': 'a', 'key3': 3.0}, - {'key1': 2, 'key2': 'c'}, - {'key1': 1, 'key2': 'b', 'key3': 2.0}, + ({'u1': 0, 'u2': 'a', 'u3': 3.0, 'c1': True, 'r1': 5}, + {'u1': 2, 'u2': 'c', 'c1': True, 'r1': 5}, + {'u1': 1, 'u2': 'b', 'u3': 2.0, 'c1': True}, ), ) @pytest.mark.parametrize("in_dicts", metasum_test_dicts) -def test_meta_summary_rt(in_dicts): +def test_meta_summary_basics(in_dicts): msum = MetaSummary() + all_keys = set() for in_dict in in_dicts: msum.append(in_dict) + for key in in_dict.keys(): + all_keys.add(key) + assert all_keys == set(msum.keys()) + for key in msum.const_keys(): + assert key.startswith('c') + for key in msum.unique_keys(): + assert key.startswith('u') + for key in msum.repeating_keys(): + assert key.startswith('r') for in_idx in range(len(in_dicts)): out_dict = msum.get_meta(in_idx) assert out_dict == in_dicts[in_idx] From 4ba6a733b0e4e711914f2acfe32ec0928880dc3d Mon Sep 17 00:00:00 2001 From: Brendan Moloney Date: Fri, 9 Jul 2021 16:35:57 -0700 Subject: [PATCH 4/7] BF: Add bitarray dependency --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 85aebfee7d..23ac6fce0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,7 @@ python_requires = >=3.6 install_requires = numpy >=1.13 packaging >=14.3 + bitarray zip_safe = False packages = find: From 583e0aa551cf1c790e5d299180356cdebbe46b22 Mon Sep 17 00:00:00 2001 From: Brendan Moloney Date: Fri, 9 Jul 2021 16:58:24 -0700 Subject: [PATCH 5/7] ENH: Make ValueIndices.to_list much more efficient --- nibabel/metasum.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/nibabel/metasum.py b/nibabel/metasum.py index d1e84dad55..2312dc2df4 100644 --- a/nibabel/metasum.py +++ b/nibabel/metasum.py @@ -165,7 +165,15 @@ def get_value(self, idx): def to_list(self): '''Convert back to a list of values''' - return [self.get_value(i) for i in range(self.n_input)] + if self._const_val is not _NoValue: + return [self._const_val] * self._n_input + res = [_NoValue] * self._n_input + for val, idx in self._unique_vals.items(): + res[idx] = val + for val, ba in self._val_bitarrs.items(): + for idx in self._extract_indices(ba): + res[idx] = val + return res def extend(self, values): '''Add more values to the end of any existing ones''' From 8946f163e5b9cfe4a140d353d5f37dcfdd31fe8d Mon Sep 17 00:00:00 2001 From: moloney Date: Mon, 12 Jul 2021 12:24:38 -0700 Subject: [PATCH 6/7] BF: Add dataclasses backport for 3.6 Co-authored-by: Chris Markiewicz --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 23ac6fce0b..c84e8cc894 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,7 @@ install_requires = numpy >=1.13 packaging >=14.3 bitarray + dataclasses ; python_version < "3.7" zip_safe = False packages = find: From 59fab276d0730989c516158672cd54be7a12bf06 Mon Sep 17 00:00:00 2001 From: Brendan Moloney Date: Mon, 12 Jul 2021 20:29:25 -0700 Subject: [PATCH 7/7] ENH: Get the nd_sort method mostly working w/ basic tests --- nibabel/metasum.py | 60 +++++++++++++++++++------- nibabel/tests/test_metasum.py | 79 +++++++++++++++++++++++++++++++++-- 2 files changed, 120 insertions(+), 19 deletions(-) diff --git a/nibabel/metasum.py b/nibabel/metasum.py index 2312dc2df4..2183cc249a 100644 --- a/nibabel/metasum.py +++ b/nibabel/metasum.py @@ -100,6 +100,15 @@ def __getitem__(self, value): ba = self._val_bitarrs[value] return list(self._extract_indices(ba)) + def first(self, value): + '''Return the first index where this value appears''' + if self._const_val == value: + return 0 + idx = self._unique_vals.get(value) + if idx is not None: + return idx + return self._val_bitarrs[value].index(True) + def values(self): '''Generate each unique value that has been seen''' if self._const_val is not _NoValue: @@ -339,8 +348,15 @@ def get_block_size(self): return None return block_size - def is_subpartition(self, other): - '''''' + def is_orthogonal(self, other, size=1): + '''Check our value's indices overlaps each from `other` exactly `size` times + ''' + other_bas = {v: other.get_mask(v) for v in other.values()} + for val in self.values(): + for other_val, other_ba in other_bas.items(): + if self.count(val, mask=other_ba) != size: + return False + return True def _extract_indices(self, ba): '''Generate integer indices from bitarray representation''' @@ -416,7 +432,7 @@ def _extend_const(self, other): class DimTypes(IntEnum): - '''Enmerate the three types of nD dimensions''' + '''Enumerate the three types of nD dimensions''' SLICE = 1 TIME = 2 PARAM = 3 @@ -556,8 +572,9 @@ def nd_sort(self, dims): last_dim = dim # Pull out info about different types of dims - n_slices = None - n_vol = None + n_input = self._n_input + total_vol = None + slice_dim = None time_dim = None param_dims = [] n_params = [] @@ -568,9 +585,10 @@ def nd_sort(self, dims): dim_vidx = self._v_idxs[dim.key] dim_type = dim.dim_type if dim_type is DimTypes.SLICE: + slice_dim = dim n_slices = len(dim_vidx) - n_vol = dim_vidx.get_block_size() - if n_vol is None: + total_vol = dim_vidx.get_block_size() + if total_vol is None: raise NdSortError("There are missing or extra slices") shape.append(n_slices) curr_size *= n_slices @@ -583,29 +601,39 @@ def nd_sort(self, dims): n_param = len(dim_vidx) n_params.append(n_param) total_params *= n_param - if n_vol is None: - n_vol = self._n_input + if total_vol is None: + total_vol = n_input - # Size of the time dimension must be infered from the size of the other dims + # Size of the time dimension must be inferred from the size of the other dims n_time = 1 + prev_dim = slice_dim if time_dim is not None: - n_time, rem = divmod(n_vol, total_params) + n_time, rem = divmod(total_vol, total_params) if rem != 0: - raise NdSortError(f"The combined parameters don't evenly divide inputs") + raise NdSortError("The combined parameters don't evenly divide inputs") shape.append(n_time) curr_size *= n_time + prev_dim = time_dim - # Complete the "shape", and do a more detailed check that our param dims make sense + # Complete the "shape", and do a more detailed check that our dims make sense for dim, n_param in zip(param_dims, n_params): dim_vidx = self._v_idxs[dim.key] - if dim_vidx.get_block_size() != curr_size: + if dim_vidx.get_block_size() != n_input // n_param: raise NdSortError(f"The parameter {dim.key} doesn't evenly divide inputs") + if prev_dim is not None and prev_dim.dim_type != DimTypes.TIME: + count_per = (curr_size // shape[-1]) * (n_input // (curr_size * n_param)) + if not self._v_idxs[prev_dim.key].is_orthogonal(dim_vidx, count_per): + raise NdSortError("The dimensions are not orthogonal") shape.append(n_param) curr_size *= n_param + prev_dim = dim # Extract dim keys for each input and do the actual sort sort_keys = [(idx, tuple(self.get_val(idx, dim.key) for dim in reversed(dims))) - for idx in range(self._n_input)] + for idx in range(n_input)] sort_keys.sort(key=lambda x: x[1]) - # TODO: Finish this + # TODO: If we have non-singular time dimension we need to do some additional + # validation checks here after sorting. + + return tuple(shape), [x[0] for x in sort_keys] diff --git a/nibabel/tests/test_metasum.py b/nibabel/tests/test_metasum.py index c0aced4d2a..258fb76e87 100644 --- a/nibabel/tests/test_metasum.py +++ b/nibabel/tests/test_metasum.py @@ -1,6 +1,9 @@ -from ..metasum import MetaSummary, ValueIndices +import random import pytest +import numpy as np + +from ..metasum import DimIndex, DimTypes, MetaSummary, ValueIndices vidx_test_patterns = ([0] * 8, @@ -14,7 +17,7 @@ @pytest.mark.parametrize("in_list", vidx_test_patterns) def test_value_indices_basics(in_list): - '''Test we can roundtrip list -> ValueIndices -> list''' + '''Test basic ValueIndices behavior''' vidx = ValueIndices(in_list) assert vidx.n_input == len(in_list) assert len(vidx) == len(set(in_list)) @@ -22,7 +25,7 @@ def test_value_indices_basics(in_list): for val in vidx.values(): assert vidx.count(val) == in_list.count(val) for in_idx in vidx[val]: - assert in_list[in_idx] == val + assert in_list[in_idx] == val == vidx.get_value(in_idx) out_list = vidx.to_list() assert in_list == out_list @@ -78,3 +81,73 @@ def test_meta_summary_basics(in_dicts): assert out_dict == in_dicts[in_idx] for key, in_val in in_dicts[in_idx].items(): assert in_val == msum.get_val(in_idx, key) + + +def _make_nd_meta(shape, dim_info, const_meta=None): + if const_meta is None: + const_meta = {'series_number': '5'} + meta_seq = [] + for nd_idx in np.ndindex(*shape): + curr_meta = {} + curr_meta.update(const_meta) + for dim, dim_idx in zip(dim_info, nd_idx): + curr_meta[dim.key] = dim_idx + meta_seq.append(curr_meta) + return meta_seq + + +ndsort_test_args = (((3,), + (DimIndex(DimTypes.SLICE, 'slice_location'),), + None), + ((3, 5), + (DimIndex(DimTypes.SLICE, 'slice_location'), + DimIndex(DimTypes.TIME, 'acq_time')), + None), + ((3, 5), + (DimIndex(DimTypes.SLICE, 'slice_location'), + DimIndex(DimTypes.PARAM, 'inversion_time')), + None), + ((3, 5, 7), + (DimIndex(DimTypes.SLICE, 'slice_location'), + DimIndex(DimTypes.TIME, 'acq_time'), + DimIndex(DimTypes.PARAM, 'echo_time')), + None), + ((3, 5, 7), + (DimIndex(DimTypes.SLICE, 'slice_location'), + DimIndex(DimTypes.PARAM, 'inversion_time'), + DimIndex(DimTypes.PARAM, 'echo_time')), + None), + ((5, 3), + (DimIndex(DimTypes.TIME, 'acq_time'), + DimIndex(DimTypes.PARAM, 'echo_time')), + None), + ((3, 5, 7), + (DimIndex(DimTypes.TIME, 'acq_time'), + DimIndex(DimTypes.PARAM, 'inversion_time'), + DimIndex(DimTypes.PARAM, 'echo_time')), + None), + ((5, 7), + (DimIndex(DimTypes.PARAM, 'inversion_time'), + DimIndex(DimTypes.PARAM, 'echo_time')), + None), + ((5, 7, 3), + (DimIndex(DimTypes.PARAM, 'inversion_time'), + DimIndex(DimTypes.PARAM, 'echo_time'), + DimIndex(DimTypes.PARAM, 'repetition_time')), + None), + ) + + +@pytest.mark.parametrize("shape,dim_info,const_meta", ndsort_test_args) +def test_ndsort(shape, dim_info, const_meta): + meta_seq = _make_nd_meta(shape, dim_info, const_meta) + rand_idx_seq = [(i, m) for i, m in enumerate(meta_seq)] + # TODO: Use some pytest plugin to manage randomness? Just use fixed seed? + random.shuffle(rand_idx_seq) + rand_idx = [x[0] for x in rand_idx_seq] + rand_seq = [x[1] for x in rand_idx_seq] + msum = MetaSummary() + for meta in rand_seq: + msum.append(meta) + out_shape, out_idxs = msum.nd_sort(dim_info) + assert shape == out_shape