Skip to content

Commit

Permalink
Add base methods for iterator serialization (#924)
Browse files Browse the repository at this point in the history
Co-authored-by: Ryan Ly <rly@lbl.gov>
  • Loading branch information
CodyCBakerPhD and rly authored Aug 8, 2023
1 parent ca7722f commit 3f3586a
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New features and minor improvements
- Increase raw data chunk cache size for reading HDF5 files from 1 MiB to 32 MiB. @bendichter, @rly [#925](https://github.com/hdmf-dev/hdmf/pull/925)
- Increase default chunk size for `GenericDataChunkIterator` from 1 MB to 10 MB. @bendichter, @rly [#925](https://github.com/hdmf-dev/hdmf/pull/925)
- Added the magic `__reduce__` method as well as two private semi-abstract helper methods to enable pickling of the `GenericDataChunkIterator`. @codycbakerphd [#924](https://github.com/hdmf-dev/hdmf/pull/924)

## HDMF 3.8.1 (July 25, 2023)

Expand Down
3 changes: 1 addition & 2 deletions src/hdmf/backends/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def read(self, **kwargs):

return container

@docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'},
allow_extra=True)
@docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'}, allow_extra=True)
def write(self, **kwargs):
"""Write a container to the IO source."""
container = popargs('container', kwargs)
Expand Down
54 changes: 35 additions & 19 deletions src/hdmf/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from warnings import warn
from typing import Tuple
from typing import Tuple, Callable
from itertools import product, chain

import h5py
Expand Down Expand Up @@ -190,9 +190,10 @@ def __init__(self, **kwargs):
HDF5 recommends chunk size in the range of 2 to 16 MB for optimal cloud performance.
https://youtu.be/rcS5vt-mKok?t=621
"""
buffer_gb, buffer_shape, chunk_mb, chunk_shape, self.display_progress, self.progress_bar_options = getargs(
buffer_gb, buffer_shape, chunk_mb, chunk_shape, self.display_progress, progress_bar_options = getargs(
"buffer_gb", "buffer_shape", "chunk_mb", "chunk_shape", "display_progress", "progress_bar_options", kwargs
)
self.progress_bar_options = progress_bar_options or dict()

if buffer_gb is None and buffer_shape is None:
buffer_gb = 1.0
Expand Down Expand Up @@ -264,15 +265,13 @@ def __init__(self, **kwargs):
)

if self.display_progress:
if self.progress_bar_options is None:
self.progress_bar_options = dict()

try:
from tqdm import tqdm

if "total" in self.progress_bar_options:
warn("Option 'total' in 'progress_bar_options' is not allowed to be over-written! Ignoring.")
self.progress_bar_options.pop("total")

self.progress_bar = tqdm(total=self.num_buffers, **self.progress_bar_options)
except ImportError:
warn(
Expand Down Expand Up @@ -345,12 +344,6 @@ def _get_default_buffer_shape(self, **kwargs) -> Tuple[int, ...]:
]
)

def recommended_chunk_shape(self) -> Tuple[int, ...]:
return self.chunk_shape

def recommended_data_shape(self) -> Tuple[int, ...]:
return self.maxshape

def __iter__(self):
    """Return ``self``; iteration state is managed by ``__next__``."""
    return self

Expand All @@ -371,6 +364,11 @@ def __next__(self):
self.progress_bar.write("\n") # Allows text to be written to new lines after completion
raise StopIteration

def __reduce__(self) -> Tuple[Callable, Iterable]:
    """Support pickling via the ``_to_dict``/``_from_dict`` pair.

    Follows the pickle protocol: returns a two-tuple of (callable, args),
    where the callable is the subclass-provided ``_from_dict`` reconstructor
    and the single argument is the state mapping produced by ``_to_dict``.
    """
    return (self._from_dict, (self._to_dict(),))

@abstractmethod
def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
"""
Expand All @@ -391,24 +389,42 @@ def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
"""
raise NotImplementedError("The data fetching method has not been built for this DataChunkIterator!")

@property
def maxshape(self) -> Tuple[int, ...]:
return self._maxshape

@abstractmethod
def _get_maxshape(self) -> Tuple[int, ...]:
    """
    Retrieve the maximum bounds of the data shape using minimal I/O.

    Subclasses must override this; the base implementation only raises.

    Returns
    -------
    tuple of int
        The full (maximum) shape of the underlying data.
    """
    # Fixed message: this method is the getter backing the `maxshape` property, not a setter.
    raise NotImplementedError("The getter for the maxshape property has not been built for this DataChunkIterator!")

@property
def dtype(self) -> np.dtype:
return self._dtype

@abstractmethod
def _get_dtype(self) -> np.dtype:
    """
    Retrieve the dtype of the data using minimal I/O.

    Subclasses must override this; the base implementation only raises.

    Returns
    -------
    numpy.dtype
        The dtype of the underlying data.
    """
    # Fixed message: this method is the getter backing the `dtype` property, not a setter.
    raise NotImplementedError("The getter for the internal dtype has not been built for this DataChunkIterator!")

def _to_dict(self) -> dict:
    """Serialize this iterator's state into a picklable dictionary.

    Optional hook for child classes to enable pickling (required for
    multiprocessing); the base implementation only raises.
    """
    message = "The `._to_dict()` method for pickling has not been defined for this DataChunkIterator!"
    raise NotImplementedError(message)

@staticmethod
def _from_dict(dictionary: dict) -> "GenericDataChunkIterator":
    """Reconstruct an iterator instance from a ``_to_dict`` state mapping.

    Optional hook for child classes to enable pickling (required for
    multiprocessing); the base implementation only raises.

    Parameters
    ----------
    dictionary : dict
        The state mapping produced by the matching ``_to_dict``.

    Returns
    -------
    GenericDataChunkIterator
        A reconstructed iterator (subclasses must override to provide this).
    """
    # Parameter renamed from `self` — misleading on a @staticmethod — to `dictionary`,
    # matching the subclass convention; all call sites pass it positionally.
    # The return annotation now reflects that overrides return an instance, not a Callable.
    raise NotImplementedError(
        "The `._from_dict()` method for pickling has not been defined for this DataChunkIterator!"
    )

def recommended_chunk_shape(self) -> Tuple[int, ...]:
    """Return the chunk shape selected for this iterator (``self.chunk_shape``)."""
    return self.chunk_shape

def recommended_data_shape(self) -> Tuple[int, ...]:
    """Return the full data shape, i.e. the ``maxshape`` property."""
    return self.maxshape

@property
def maxshape(self) -> Tuple[int, ...]:
    # Returns the cached `_maxshape` value — presumably populated via `_get_maxshape`
    # during construction; TODO(review) confirm against __init__.
    return self._maxshape
@property
def dtype(self) -> np.dtype:
    # Returns the cached `_dtype` value — presumably populated via `_get_dtype`
    # during construction; TODO(review) confirm against __init__.
    return self._dtype


class DataChunkIterator(AbstractDataChunkIterator):
"""
Expand Down
60 changes: 59 additions & 1 deletion tests/unit/utils_test/test_core_GenericDataChunkIterator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import unittest
import pickle
import numpy as np
from pathlib import Path
from tempfile import mkdtemp
from shutil import rmtree
from typing import Tuple, Iterable
from typing import Tuple, Iterable, Callable
from sys import version_info

import h5py
from numpy.testing import assert_array_equal

from hdmf.data_utils import GenericDataChunkIterator
from hdmf.testing import TestCase
Expand All @@ -18,6 +20,30 @@
TQDM_INSTALLED = False


class TestPickleableNumpyArrayDataChunkIterator(GenericDataChunkIterator):
    """Concrete iterator over an in-memory numpy array that supports pickling."""

    def __init__(self, array: np.ndarray, **kwargs):
        # Keep the constructor kwargs so `_to_dict` can round-trip them.
        self._kwargs = kwargs
        self.array = array
        super().__init__(**kwargs)

    def _get_data(self, selection) -> np.ndarray:
        return self.array[selection]

    def _get_maxshape(self) -> Tuple[int, ...]:
        return self.array.shape

    def _get_dtype(self) -> np.dtype:
        return self.array.dtype

    def _to_dict(self) -> dict:
        # Serialize the array itself so the state mapping is fully picklable.
        return {"array": pickle.dumps(self.array), "kwargs": self._kwargs}

    @staticmethod
    def _from_dict(dictionary: dict) -> Callable:
        loaded_array = pickle.loads(dictionary["array"])
        return TestPickleableNumpyArrayDataChunkIterator(array=loaded_array, **dictionary["kwargs"])


class GenericDataChunkIteratorTests(TestCase):
class TestNumpyArrayDataChunkIterator(GenericDataChunkIterator):
def __init__(self, array: np.ndarray, **kwargs):
Expand Down Expand Up @@ -204,6 +230,29 @@ def test_progress_bar_assertion(self):
progress_bar_options=dict(total=5),
)

def test_private_to_dict_assertion(self):
    """Calling `_to_dict` on a subclass that does not override it raises NotImplementedError."""
    with self.assertRaisesWith(
        exc_type=NotImplementedError,
        exc_msg="The `._to_dict()` method for pickling has not been defined for this DataChunkIterator!"
    ):
        iterator = self.TestNumpyArrayDataChunkIterator(array=self.test_array)
        _ = iterator._to_dict()

def test_private_from_dict_assertion(self):
    """Calling `_from_dict` on a subclass that does not override it raises NotImplementedError."""
    with self.assertRaisesWith(
        exc_type=NotImplementedError,
        exc_msg="The `._from_dict()` method for pickling has not been defined for this DataChunkIterator!"
    ):
        _ = self.TestNumpyArrayDataChunkIterator._from_dict(dict())

def test_direct_pickle_assertion(self):
    """pickle.dumps on an iterator lacking `_to_dict` fails via `__reduce__` with a clear message."""
    with self.assertRaisesWith(
        exc_type=NotImplementedError,
        exc_msg="The `._to_dict()` method for pickling has not been defined for this DataChunkIterator!"
    ):
        iterator = self.TestNumpyArrayDataChunkIterator(array=self.test_array)
        _ = pickle.dumps(iterator)

def test_maxshape_attribute_contains_int_type(self):
"""Motivated by issues described in https://github.com/hdmf-dev/hdmf/pull/780 & 781 regarding return types."""
self.check_all_of_iterable_is_python_int(
Expand Down Expand Up @@ -377,3 +426,12 @@ def test_tqdm_not_installed(self):
display_progress=True,
)
self.assertFalse(dci.display_progress)

def test_pickle(self):
    """Round-trip the pickleable iterator through pickle and compare its state."""
    source_iterator = TestPickleableNumpyArrayDataChunkIterator(array=self.test_array)
    roundtrip_iterator = pickle.loads(pickle.dumps(source_iterator))

    assert isinstance(roundtrip_iterator, TestPickleableNumpyArrayDataChunkIterator)
    assert roundtrip_iterator.chunk_shape == source_iterator.chunk_shape
    assert roundtrip_iterator.buffer_shape == source_iterator.buffer_shape
    assert_array_equal(roundtrip_iterator.array, source_iterator.array)

0 comments on commit 3f3586a

Please sign in to comment.