Skip to content

Commit

Permalink
Add cache option in GridPatchDataset (Project-MONAI#7180)
Browse files Browse the repository at this point in the history
Part of Project-MONAI#6904

### Description
- Fix inefficient patching in `PatchDataset`
- Add cache option in `GridPatchDataset`

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Mark Graham <markgraham539@gmail.com>
  • Loading branch information
3 people authored and marksgraham committed Jan 30, 2024
1 parent 8f781eb commit bf166b5
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 46 deletions.
218 changes: 185 additions & 33 deletions monai/data/grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,30 @@

from __future__ import annotations

from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
import sys
import warnings
from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence
from copy import deepcopy
from multiprocessing.managers import ListProxy
from multiprocessing.pool import ThreadPool
from typing import TYPE_CHECKING

import numpy as np
import torch

from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayTensor
from monai.data.dataset import Dataset
from monai.data.iterable_dataset import IterableDataset
from monai.data.utils import iter_patch
from monai.transforms import apply_transform
from monai.utils import NumpyPadMode, ensure_tuple, first
from monai.data.utils import iter_patch, pickle_hashing
from monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous
from monai.utils import NumpyPadMode, ensure_tuple, first, min_version, optional_import

if TYPE_CHECKING:
from tqdm import tqdm

has_tqdm = True
else:
tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"]

Expand Down Expand Up @@ -184,6 +196,25 @@ class GridPatchDataset(IterableDataset):
see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`.
transform: a callable data transform operates on the patches.
with_coordinates: whether to yield the coordinates of each patch, default to `True`.
cache: whether to use cache mache mechanism, default to `False`.
see also: :py:class:`monai.data.CacheDataset`.
cache_num: number of items to be cached. Default is `sys.maxsize`.
will take the minimum of (cache_num, data_length x cache_rate, data_length).
cache_rate: percentage of cached data in total, default is 1.0 (cache all).
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads if computing cache in the initialization.
If num_workers is None then the number returned by os.cpu_count() is used.
If a value less than 1 is specified, 1 will be used instead.
progress: whether to display a progress bar.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
"""

Expand All @@ -193,27 +224,148 @@ def __init__(
patch_iter: Callable,
transform: Callable | None = None,
with_coordinates: bool = True,
cache: bool = False,
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: int | None = 1,
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
hash_func: Callable[..., bytes] = pickle_hashing,
) -> None:
super().__init__(data=data, transform=None)
if transform is not None and not isinstance(transform, Compose):
transform = Compose(transform)
self.patch_iter = patch_iter
self.patch_transform = transform
self.with_coordinates = with_coordinates
self.set_num = cache_num
self.set_rate = cache_rate
self.progress = progress
self.copy_cache = copy_cache
self.as_contiguous = as_contiguous
self.hash_func = hash_func
self.num_workers = num_workers
if self.num_workers is not None:
self.num_workers = max(int(self.num_workers), 1)
self._cache: list | ListProxy = []
self._cache_other: list | ListProxy = []
self.cache = cache
self.first_random: int | None = None
if self.patch_transform is not None:
self.first_random = self.patch_transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)

def __iter__(self):
for image in super().__iter__():
for patch, *others in self.patch_iter(image):
out_patch = patch
if self.patch_transform is not None:
out_patch = apply_transform(self.patch_transform, patch, map_items=False)
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
yield out_patch, others[0]
else:
yield out_patch
if self.cache:
if isinstance(data, Iterator):
raise TypeError("Data can not be iterator when cache is True")
self.set_data(data) # type: ignore

def set_data(self, data: Sequence) -> None:
"""
Set the input data and run deterministic transforms to generate cache content.
Note: should call this func after an entire epoch and must set `persistent_workers=False`
in PyTorch DataLoader, because it needs to create new worker processes based on new
generated cache content.
"""
self.data = data

# only compute cache for the unique items of dataset, and record the last index for duplicated items
mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}
self.cache_num = min(int(self.set_num), int(len(mapping) * self.set_rate), len(mapping))
self._hash_keys = list(mapping)[: self.cache_num]
indices = list(mapping.values())[: self.cache_num]
self._cache, self._cache_other = zip(*self._fill_cache(indices)) # type: ignore

def _fill_cache(self, indices=None) -> list:
"""
Compute and fill the cache content from data source.
Args:
indices: target indices in the `self.data` source to compute cache.
if None, use the first `cache_num` items.
"""
if self.cache_num <= 0:
return []
if indices is None:
indices = list(range(self.cache_num))
if self.progress and not has_tqdm:
warnings.warn("tqdm is not installed, will not show the caching progress bar.")

pfunc = tqdm if self.progress and has_tqdm else (lambda v, **_: v)
with ThreadPool(self.num_workers) as p:
return list(pfunc(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset"))

def _load_cache_item(self, idx: int):
"""
Args:
idx: the index of the input data sequence.
"""
item = self.data[idx] # type: ignore
patch_cache, other_cache = [], []
for patch, *others in self.patch_iter(item):
if self.first_random is not None:
patch = self.patch_transform(patch, end=self.first_random, threading=True) # type: ignore

if self.as_contiguous:
patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format)
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
other_cache.append(others[0])
patch_cache.append(patch)
return patch_cache, other_cache

def _generate_patches(self, src, **apply_args):
"""
yield patches optionally post-processed by transform.
Args:
src: a iterable of image patches.
apply_args: other args for `self.patch_transform`.
"""
for patch, *others in src:
out_patch = patch
if self.patch_transform is not None:
out_patch = self.patch_transform(patch, **apply_args)
if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords
yield out_patch, others[0]
else:
yield out_patch

class PatchDataset(Dataset):
def __iter__(self):
if self.cache:
cache_index = None
for image in super().__iter__():
key = self.hash_func(image)
if key in self._hash_keys:
# if existing in cache, try to get the index in cache
cache_index = self._hash_keys.index(key)
if cache_index is None:
# no cache for this index, execute all the transforms directly
yield from self._generate_patches(self.patch_iter(image))
else:
if self._cache is None:
raise RuntimeError(
"Cache buffer is not initialized, please call `set_data()` before epoch begins."
)
data = self._cache[cache_index] # type: ignore
other = self._cache_other[cache_index] # type: ignore

# load data from cache and execute from the first random transform
data = deepcopy(data) if self.copy_cache else data
yield from self._generate_patches(zip(data, other), start=self.first_random)
else:
for image in super().__iter__():
yield from self._generate_patches(self.patch_iter(image))


class PatchDataset(IterableDataset):
"""
returns a patch from an image dataset.
Yields patches from data read from an image dataset.
The patches are generated by a user-specified callable `patch_func`,
and are optionally post-processed by `transform`.
For example, to generate random patch samples from an image dataset:
Expand Down Expand Up @@ -263,26 +415,26 @@ def __init__(
samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.
transform: transform applied to each patch.
"""
super().__init__(data=data, transform=transform)
super().__init__(data=data, transform=None)

self.patch_func = patch_func
if samples_per_image <= 0:
raise ValueError("sampler_per_image must be a positive integer.")
self.samples_per_image = int(samples_per_image)
self.patch_transform = transform

def __len__(self) -> int:
return len(self.data) * self.samples_per_image

def _transform(self, index: int):
image_id = int(index / self.samples_per_image)
image = self.data[image_id]
patches = self.patch_func(image)
if len(patches) != self.samples_per_image:
raise RuntimeWarning(
f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
)
patch_id = (index - image_id * self.samples_per_image) * (-1 if index < 0 else 1)
patch = patches[patch_id]
if self.transform is not None:
patch = apply_transform(self.transform, patch, map_items=False)
return patch
return len(self.data) * self.samples_per_image # type: ignore

def __iter__(self):
for image in super().__iter__():
patches = self.patch_func(image)
if len(patches) != self.samples_per_image:
raise RuntimeWarning(
f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
)
for patch in patches:
out_patch = patch
if self.patch_transform is not None:
out_patch = apply_transform(self.patch_transform, patch, map_items=False)
yield out_patch
55 changes: 45 additions & 10 deletions tests/test_grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,18 @@ def test_shape(self):
self.assertEqual(sorted(output), sorted(expected))

def test_loading_array(self):
set_determinism(seed=1234)
# test sequence input data with images
images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]
# image level
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234)
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity)
# use the grid patch dataset
for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
Expand All @@ -129,9 +128,7 @@ def test_loading_array(self):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
np.array(
[[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
),
np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),
rtol=1e-3,
)
np.testing.assert_allclose(
Expand Down Expand Up @@ -164,7 +161,7 @@ def test_loading_dict(self):
self.assertListEqual(item[0]["metadata"], ["test string", "test string"])
np.testing.assert_allclose(
item[0]["image"],
np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
Expand All @@ -173,15 +170,53 @@ def test_loading_dict(self):
np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
np.testing.assert_allclose(
item[0]["image"],
np.array(
[[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
),
np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),
rtol=1e-3,
)
np.testing.assert_allclose(
item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5
)

def test_set_data(self):
from monai.transforms import Compose, Lambda, RandLambda

images = [np.arange(2, 18, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]

transform = Compose(
[Lambda(func=lambda x: np.array(x * 10)), RandLambda(func=lambda x: x + 1)], map_items=False
)
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
dataset = GridPatchDataset(
data=images,
patch_iter=patch_iter,
transform=transform,
cache=True,
cache_rate=1.0,
copy_cache=not sys.platform == "linux",
)

num_workers = 2 if sys.platform == "linux" else 0
for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
# simulate another epoch, the cache content should not be modified
for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)

# update the datalist and fill the cache content
data_list2 = [np.arange(1, 17, dtype=float).reshape(1, 4, 4)]
dataset.set_data(data=data_list2)
# rerun with updated cache content
for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0], np.array([[[[91, 101], [131, 141]]], [[[111, 121], [151, 161]]]]), rtol=1e-4
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit bf166b5

Please sign in to comment.