@property
def sample_classes(self):
    """Group sample names by identical stored array data.

    Samples whose backend specs are identical necessarily reference the
    same stored data, so names are first bucketed by spec and each unique
    spec's data is read only once.

    Returns
    -------
    dict
        Maps a hashable tuple of the (flattened) array values to the list
        of sample names containing exactly that data.
    """
    grouped_spec_names = defaultdict(list)
    for name, bespec in self._sspecs.items():
        grouped_spec_names[bespec].append(name)

    grouped_data_names = {}
    for spec, names in grouped_spec_names.items():
        data = self._fs[spec.backend].read_data(spec)
        # ravel before tolist(): on a >1-D array, tuple(data.tolist())
        # would be a tuple of nested (unhashable) lists and raise
        # TypeError when used as a dict key. Flattening keeps 1-D keys
        # identical to before and makes n-D data work too.
        grouped_data_names[tuple(data.ravel().tolist())] = names
    return grouped_data_names
import numpy as np


def pnorm(p):
    """Return *p* normalized so its entries sum to one.

    Parameters
    ----------
    p : list or tuple
        weight values; need not sum to one, but the total must be positive.

    Returns
    -------
    list or tuple
        ``p`` unchanged if it already sums to (approximately) one,
        otherwise a list of ``p[i] / sum(p)``.

    Raises
    ------
    ValueError
        if ``p`` is not a list/tuple, or its total is not positive (a
        zero or negative total cannot be normalized into probabilities;
        previously this surfaced as a bare ``ZeroDivisionError``).
    """
    if not isinstance(p, (list, tuple)):
        raise ValueError(f'probability map {p} must be of type (list, tuple), not {type(p)}')
    ptot = np.sum(p)
    if ptot <= 0:
        raise ValueError(f'probability map {p} must have a positive sum, not {ptot}')
    if not np.allclose(ptot, 1):
        p = [i / ptot for i in p]
    return p


def multinomial(num_samples, p):
    """Draw per-event counts from a multinomial distribution.

    Parameters
    ----------
    num_samples : int
        total number of draws.
    p : list or tuple
        event weights; normalized via :func:`pnorm` before sampling.

    Returns
    -------
    numpy.ndarray
        array of length ``len(p)`` whose entries are how many times each
        event occurred; the entries sum to ``num_samples``.
    """
    valid_p = pnorm(p)
    res = np.random.multinomial(num_samples, valid_p)
    return res


class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method,
    providing a way to iterate over keys of dataset samples, and a
    :meth:`__len__` method that returns the length of the returned
    iterator.

    .. note:: No default ``__len__`` is provided on purpose. Returning
       ``NotImplemented`` makes ``len()`` raise a confusing ``TypeError``,
       and raising ``NotImplementedError`` defeats fallbacks (e.g.
       ``list(x)`` probes ``__len__`` and only falls back to ``__iter__``
       when the method is *absent*). Omitting it entirely yields Python's
       own, correct ``TypeError`` when a subclass without a length is
       measured.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError


class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source: dataset to sample from; must expose ``keys()`` and
            ``__len__`` (e.g. a hangar arrayset).
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(self.data_source.keys())

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a
    shuffled dataset. If with replacement, then user can specify
    :attr:`num_samples` to draw.

    Arguments:
        data_source: dataset to sample from; must expose ``keys()`` and
            ``__len__``.
        replacement (bool): samples are drawn with replacement if ``True``,
            default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
            This argument is supposed to be specified only when `replacement`
            is ``True``.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self):
        # dataset size might change at runtime, so fall back to the live
        # length whenever an explicit count was not given.
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        keys = list(self.data_source.keys())
        if self.replacement:
            # draw key positions uniformly, repeats allowed
            choose = np.random.randint(low=0, high=n, size=(self.num_samples,),
                                       dtype=np.int64).tolist()
            return (keys[x] for x in choose)
        # without replacement num_samples == n, so this is a full shuffle
        choose = np.random.permutation(self.num_samples)
        return (keys[x] for x in choose)

    def __len__(self):
        return self.num_samples


class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without
    replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        choose = np.random.permutation(len(self.indices))
        return (self.indices[x] for x in choose)

    def __len__(self):
        return len(self.indices)


class WeightedRandomSampler(Sampler):
    r"""Samples indices from ``[0,..,len(weights)-1]`` with given
    probabilities (weights).

    Args:
        weights (sequence) : a sequence of weights, not necessary summing up
            to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True`` (default), samples are drawn with
            replacement. If not, they are drawn without replacement, which
            means that when a sample index is drawn for a row, it cannot be
            drawn again for that row (``num_samples`` must then not exceed
            the number of non-zero weights).

    Example (illustrative; actual draws are random):
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement=True):
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = tuple(weights)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        # BUGFIX: the previous implementation iterated
        # np.random.multinomial(...), which yields per-index *counts*
        # (an array of length len(weights) summing to num_samples), not
        # the num_samples sampled indices the docstring promises. Draw
        # actual index values instead.
        valid_p = pnorm(self.weights)
        chosen = np.random.choice(len(self.weights), size=self.num_samples,
                                  replace=self.replacement, p=valid_p)
        # tolist() yields plain Python ints rather than np.int64 scalars
        return iter(chosen.tolist())

    def __len__(self):
        return self.num_samples


class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of keys.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch
            if its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            # error message previously pointed at torch.utils.data.Sampler;
            # this is hangar's own Sampler hierarchy.
            raise ValueError("sampler should be an instance of "
                             "Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            # ceiling division: the final partial batch still counts
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size