Skip to content

Commit

Permalink
working - but extremely primitive - method to sample and batch arrayse…
Browse files Browse the repository at this point in the history
…t groups
  • Loading branch information
rlizzo committed Nov 26, 2019
1 parent 60300b6 commit 53e07f4
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 102 deletions.
18 changes: 10 additions & 8 deletions src/hangar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
def raise_ImportError(message, *args, **kwargs):  # pragma: no cover
    """Stand-in callable: raises ImportError(*message*) whenever invoked.

    Used as a placeholder for optional-dependency loaders; any positional or
    keyword arguments the caller passes are accepted and ignored.
    """
    raise ImportError(message)

from .dataloaders.tfloader import make_tf_dataset
from .dataloaders.torchloader import make_torch_dataset

try: # pragma: no cover
from .dataloaders.tfloader import make_tf_dataset
except ImportError: # pragma: no cover
make_tf_dataset = partial(raise_ImportError, "Could not import tensorflow. Install dependencies")
# try: # pragma: no cover
# from .dataloaders.tfloader import make_tf_dataset
# except ImportError: # pragma: no cover
# make_tf_dataset = partial(raise_ImportError, "Could not import tensorflow. Install dependencies")

try: # pragma: no cover
from .dataloaders.torchloader import make_torch_dataset
except ImportError: # pragma: no cover
make_torch_dataset = partial(raise_ImportError, "Could not import torch. Install dependencies")
# try: # pragma: no cover
# from .dataloaders.torchloader import make_torch_dataset
# except ImportError: # pragma: no cover
# make_torch_dataset = partial(raise_ImportError, "Could not import torch. Install dependencies")
17 changes: 3 additions & 14 deletions src/hangar/arrayset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
from .records.parsing import arrayset_record_schema_db_val_from_raw_val


CompatibleArray = NamedTuple(
'CompatibleArray', [('compatible', bool), ('reason', str)])
# Result of an array compatibility check: a pass/fail flag and the
# human-readable reason for any failure.
CompatibleArray = NamedTuple(
    'CompatibleArray',
    [('compatible', bool), ('reason', str)],
)


class ArraysetDataReader(object):
Expand Down Expand Up @@ -305,18 +306,6 @@ def backend_opts(self):
"""
return self._dflt_backend_opts

@property
def sample_classes(self):
grouped_spec_names = defaultdict(list)
for name, bespec in self._sspecs.items():
grouped_spec_names[bespec].append(name)

grouped_data_names = {}
for spec, names in grouped_spec_names.items():
data = self._fs[spec.backend].read_data(spec)
grouped_data_names[tuple(data.tolist())] = names
return grouped_data_names

def keys(self, local: bool = False) -> Iterator[Union[str, int]]:
"""generator which yields the names of every sample in the arrayset
Expand Down
3 changes: 3 additions & 0 deletions src/hangar/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .grouper import GroupedArraysetDataReader

__all__ = ['GroupedArraysetDataReader']
125 changes: 125 additions & 0 deletions src/hangar/dataloaders/grouper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import numpy as np

from ..arrayset import ArraysetDataReader

from collections import defaultdict
import hashlib
from typing import Sequence, Union, Iterable, NamedTuple
import struct


# -------------------------- typehints ---------------------------------------


# Names identifying samples within an arrayset.
ArraysetSampleNames = Sequence[Union[str, int]]

# A group's data value paired with the sample name it maps to.
SampleGroup = NamedTuple(
    'SampleGroup',
    [('group', np.ndarray), ('samples', Union[str, int])],
)


# ------------------------------------------------------------------------------


def _calculate_hash_digest(data: np.ndarray) -> str:
hasher = hashlib.blake2b(data, digest_size=20)
hasher.update(struct.pack(f'<{len(data.shape)}QB', *data.shape, data.dtype.num))
digest = hasher.hexdigest()
return digest


class FakeNumpyKeyDict(object):
    """Read-only mapping from a group's data value (np.ndarray) to sample names.

    numpy arrays are not hashable, so each key array is reduced to a content
    digest via ``_calculate_hash_digest`` and looked up in
    ``group_digest_spec``; the resolved backend spec then indexes the other
    two mappings. Mutation is disallowed (PermissionError).
    """

    def __init__(self, group_spec_samples, group_spec_value, group_digest_spec):
        # backend spec -> list of sample names belonging to that group
        self._group_spec_samples = group_spec_samples
        # backend spec -> np.ndarray data value shared by the group
        self._group_spec_value = group_spec_value
        # data digest (str) -> backend spec
        self._group_digest_spec = group_digest_spec

    def __getitem__(self, key: np.ndarray) -> 'ArraysetSampleNames':
        """Return sample names of the group whose data equals ``key``.

        Raises KeyError when no group's data matches ``key``.
        """
        digest = _calculate_hash_digest(key)
        spec = self._group_digest_spec[digest]
        return self._group_spec_samples[spec]

    def get(self, key: np.ndarray, default=None) -> 'ArraysetSampleNames':
        """dict-style get: return ``default`` when ``key`` is missing.

        FIX: previously this just delegated to ``__getitem__`` and raised
        KeyError on a miss, defeating the purpose of ``get``. ``default``
        has a value so existing one-argument callers are unaffected.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key, val):
        raise PermissionError('Not User Editable')

    def __delitem__(self, key):
        raise PermissionError('Not User Editable')

    def __len__(self) -> int:
        # number of distinct groups
        return len(self._group_digest_spec)

    def __contains__(self, key: np.ndarray) -> bool:
        return _calculate_hash_digest(key) in self._group_digest_spec

    def __iter__(self) -> 'Iterable[np.ndarray]':
        # iterating the mapping yields its keys (group data arrays)
        yield from self.keys()

    def keys(self) -> 'Iterable[np.ndarray]':
        for spec in self._group_digest_spec.values():
            yield self._group_spec_value[spec]

    def values(self) -> 'Iterable[ArraysetSampleNames]':
        for spec in self._group_digest_spec.values():
            yield self._group_spec_samples[spec]

    def items(self) -> 'Iterable[Tuple[np.ndarray, ArraysetSampleNames]]':
        for spec in self._group_digest_spec.values():
            yield (self._group_spec_value[spec], self._group_spec_samples[spec])

    def __repr__(self) -> str:
        # FIX: __repr__ must return a string. The original printed to stdout
        # and implicitly returned None, so any repr()/str() call raised
        # ``TypeError: __repr__ returned non-string``.
        lines = ['Mapping: Group Data Value -> Sample Name']
        for k, v in self.items():
            lines.append(f'{k} {v}')
        return '\n'.join(lines)

    def _repr_pretty_(self, p, cycle):
        # IPython pretty-print hook
        res = 'Mapping: Group Data Value -> Sample Name \n'
        for k, v in self.items():
            res += f'\n {k} :: {v}'
        p.text(res)



# ---------------------------- MAIN METHOD ------------------------------------


class GroupedArraysetDataReader(object):
    """Wrap an arrayset and automatically discover its sample groups.

    Samples sharing an identical backend spec (and therefore identical data)
    are collected into groups at construction time.
    """

    def __init__(self, arrayset: ArraysetDataReader, *args, **kwargs):
        # TODO: Do we actually need to keep this around?
        self.__arrayset = arrayset
        self._group_spec_samples = defaultdict(list)
        self._group_spec_value = {}
        self._group_digest_spec = {}
        self._setup()
        self._group_samples = FakeNumpyKeyDict(
            self._group_spec_samples,
            self._group_spec_value,
            self._group_digest_spec)

    def _setup(self):
        # Bucket sample names by backend spec, then record each group's data
        # value and its content digest.
        for sample_name, backend_spec in self.__arrayset._sspecs.items():
            self._group_spec_samples[backend_spec].append(sample_name)
        backends = self.__arrayset._fs
        for backend_spec in self._group_spec_samples:
            value = backends[backend_spec.backend].read_data(backend_spec)
            self._group_spec_value[backend_spec] = value
            self._group_digest_spec[_calculate_hash_digest(value)] = backend_spec

    @property
    def groups(self) -> Iterable[np.ndarray]:
        """Yield the unique data value array of every discovered group."""
        return (self._group_spec_value[spec]
                for spec in self._group_digest_spec.values())

    @property
    def group_samples(self):
        """Mapping-like view from group data value to member sample names."""
        return self._group_samples
Loading

0 comments on commit 53e07f4

Please sign in to comment.