From 2c16613001b03927e23fecad08ce01ed14086b9f Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 1 Nov 2023 19:09:11 +0800 Subject: [PATCH 01/22] fix #6904 Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 157 +++++++++++++++++++++++++++++++++---- 1 file changed, 143 insertions(+), 14 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 06954e9f11..bb88ed6366 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -11,18 +11,34 @@ from __future__ import annotations +import sys from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence from copy import deepcopy +from multiprocessing.managers import ListProxy +from typing import TYPE_CHECKING import numpy as np +import torch from monai.config import KeysCollection from monai.config.type_definitions import NdarrayTensor -from monai.data.dataset import Dataset +from monai.data.dataset import CacheDataset, Dataset from monai.data.iterable_dataset import IterableDataset -from monai.data.utils import iter_patch -from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, ensure_tuple, first +from monai.data.utils import iter_patch, pickle_hashing +from monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous +from monai.utils import NumpyPadMode, ensure_tuple, first, min_version, optional_import + +if TYPE_CHECKING: + from tqdm import tqdm + + has_tqdm = True +else: + tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") + +cp, _ = optional_import("cupy") +lmdb, _ = optional_import("lmdb") +pd, _ = optional_import("pandas") +kvikio_numpy, _ = optional_import("kvikio.numpy") __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"] @@ -145,7 +161,7 @@ def __call__( yield ret, coords -class GridPatchDataset(IterableDataset): +class GridPatchDataset(IterableDataset, CacheDataset): """ Yields patches from data read from an image dataset. Typically used with `PatchIter` or `PatchIterd` so that the patches are chosen in a contiguous grid sampling scheme. @@ -193,22 +209,135 @@ def __init__( patch_iter: Callable, transform: Callable | None = None, with_coordinates: bool = True, + cache: bool = False, + cache_num: int = sys.maxsize, + cache_rate: float = 1.0, + num_workers: int | None = 1, + progress: bool = True, + copy_cache: bool = True, + as_contiguous: bool = True, + hash_func: Callable[..., bytes] = pickle_hashing, ) -> None: super().__init__(data=data, transform=None) + if transform is not None and not isinstance(transform, Compose): + transform = Compose(transform) self.patch_iter = patch_iter self.patch_transform = transform self.with_coordinates = with_coordinates + self.set_num = cache_num + self.set_rate = cache_rate + self.progress = progress + self.copy_cache = copy_cache + self.as_contiguous = as_contiguous + self.hash_func = hash_func + self.num_workers = num_workers + if self.num_workers is not None: + self.num_workers = max(int(self.num_workers), 1) + self._cache: list | ListProxy = [] + self._cache_other: list | ListProxy = [] + self.cache = cache + if self.cache: + self.set_data(data) + + def set_data(self, data: Sequence) -> None: + """ + Set the input data and run deterministic transforms to generate cache content. + + Note: should call this func after an entire epoch and must set `persistent_workers=False` + in PyTorch DataLoader, because it needs to create new worker processes based on new + generated cache content. 
+ + """ + self.data = data + + def _compute_cache_num(data_len: int): + self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len) + + # only compute cache for the unique items of dataset, and record the last index for duplicated items + mapping = {self.hash_func(v): i for i, v in enumerate(self.data)} + _compute_cache_num(len(mapping)) + self._hash_keys = list(mapping)[: self.cache_num] + indices = list(mapping.values())[: self.cache_num] + + self._cache = self._fill_cache(indices) + return + + def _load_cache_item(self, idx: int): + """ + Args: + idx: the index of the input data sequence. + """ + item = self.data[idx] + patch_cache, other_cache = [], [] + for patch, *others in self.patch_iter(item): + if self.patch_transform is not None: + first_random = self.patch_transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) + patch = self.patch_transform(patch, end=first_random, threading=True) + + if self.as_contiguous: + patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format) + if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords + other_cache.append(others[0]) + patch_cache.append(patch) + self._cache_other.append(other_cache) + return patch_cache def __iter__(self): - for image in super().__iter__(): - for patch, *others in self.patch_iter(image): - out_patch = patch - if self.patch_transform is not None: - out_patch = apply_transform(self.patch_transform, patch, map_items=False) - if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords - yield out_patch, others[0] - else: - yield out_patch + if self.cache: + cache_index = None + for image in super().__iter__(): + key = self.hash_func(image) + if key in self._hash_keys: + # if existing in cache, try to get the index in cache + cache_index = self._hash_keys.index(key) + if cache_index is None: + # no cache for this index, execute all the transforms directly + for patch, *others in self.patch_iter(image): + out_patch = patch + if self.patch_transform is not None: + out_patch = apply_transform(self.patch_transform, patch, map_items=False) + if ( + self.with_coordinates and len(others) > 0 + ): # patch_iter to yield at least 2 items: patch, coords + yield out_patch, others[0] + else: + yield out_patch + + if self._cache is None: + raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.") + data = self._cache[cache_index] + other = self._cache_other[cache_index] + + # load data from cache and execute from the first random transform + if not isinstance(self.patch_transform, Compose): + raise ValueError("transform must be an instance of monai.transforms.Compose.") + + first_random = self.patch_transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) + if first_random is not None: + data = deepcopy(data) if self.copy_cache is True else data + for out_patch, others in zip(data, other): + if self.patch_transform is not None: + out_patch = self.patch_transform(out_patch, start=first_random) + if ( + self.with_coordinates and len(others) > 0 + ): # patch_iter to yield at least 2 items: patch, coords + yield out_patch, others + else: + yield out_patch + else: + for image in super().__iter__(): + for patch, *others in self.patch_iter(image): + out_patch = patch + if self.patch_transform is not None: + out_patch = apply_transform(self.patch_transform, patch, map_items=False) + if 
self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords + yield out_patch, others[0] + else: + yield out_patch class PatchDataset(Dataset): From 0e7a36241ad62ad186218e6c11a7b461127e94f4 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 13:30:53 +0800 Subject: [PATCH 02/22] modify test Signed-off-by: KumoLiu --- tests/test_grid_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index ba33547260..af81edd8fa 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -108,11 +108,11 @@ def test_shape(self): self.assertEqual(sorted(output), sorted(expected)) def test_loading_array(self): - set_determinism(seed=1234) + # set_determinism(seed=1234) # test sequence input data with images images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] # image level - patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) + patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234) patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity) # use the grid patch dataset @@ -120,7 +120,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]), + np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]), rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) @@ -130,7 +130,7 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]] + [[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]] ), rtol=1e-3, ) @@ -164,7 +164,7 @@ def test_loading_dict(self): self.assertListEqual(item[0]["metadata"], ["test string", "test string"]) np.testing.assert_allclose( item[0]["image"], - np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]), + np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]), rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) @@ -174,7 +174,7 @@ def test_loading_dict(self): np.testing.assert_allclose( item[0]["image"], np.array( - [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]] + [[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]] ), rtol=1e-3, ) From bf0649760b31e23d1c6a910dc657802d63e9ad94 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 13:31:11 +0800 Subject: [PATCH 03/22] fix ci Signed-off-by: KumoLiu --- tests/test_grid_dataset.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index af81edd8fa..69de2d729a 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -129,9 +129,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array( - [[[[7.27427, 8.27427], [11.27427, 12.27427]]], 
[[[9.4353, 10.4353], [13.4353, 14.4353]]]] - ), + np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]), rtol=1e-3, ) np.testing.assert_allclose( @@ -173,9 +171,7 @@ def test_loading_dict(self): np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2)) np.testing.assert_allclose( item[0]["image"], - np.array( - [[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]] - ), + np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]), rtol=1e-3, ) np.testing.assert_allclose( From fa3da3876f916717fbaeb9099d7134445c522acb Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 14:25:53 +0800 Subject: [PATCH 04/22] fix mypy Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index bb88ed6366..1f45c23583 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -12,9 +12,11 @@ from __future__ import annotations import sys -from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence +import warnings +from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence, Iterator from copy import deepcopy from multiprocessing.managers import ListProxy +from multiprocessing.pool import ThreadPool from typing import TYPE_CHECKING import numpy as np @@ -161,7 +163,7 @@ def __call__( yield ret, coords -class GridPatchDataset(IterableDataset, CacheDataset): +class GridPatchDataset(IterableDataset): """ Yields patches from data read from an image dataset. Typically used with `PatchIter` or `PatchIterd` so that the patches are chosen in a contiguous grid sampling scheme. @@ -237,7 +239,9 @@ def __init__( self._cache_other: list | ListProxy = [] self.cache = cache if self.cache: - self.set_data(data) + if isinstance(data, Iterator): + raise TypeError("Data can not be iterator when cache is True") + self.set_data(data) # type: ignore def set_data(self, data: Sequence) -> None: """ @@ -262,12 +266,32 @@ def _compute_cache_num(data_len: int): self._cache = self._fill_cache(indices) return + def _fill_cache(self, indices=None) -> list: + """ + Compute and fill the cache content from data source. + + Args: + indices: target indices in the `self.data` source to compute cache. + if None, use the first `cache_num` items. + + """ + if self.cache_num <= 0: + return [] + if indices is None: + indices = list(range(self.cache_num)) + if self.progress and not has_tqdm: + warnings.warn("tqdm is not installed, will not show the caching progress bar.") + with ThreadPool(self.num_workers) as p: + if self.progress and has_tqdm: + return list(tqdm(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset")) + return list(p.imap(self._load_cache_item, indices)) + def _load_cache_item(self, idx: int): """ Args: idx: the index of the input data sequence. 
""" - item = self.data[idx] + item = self.data[idx] # type: ignore patch_cache, other_cache = [], [] for patch, *others in self.patch_iter(item): if self.patch_transform is not None: @@ -307,8 +331,8 @@ def __iter__(self): if self._cache is None: raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.") - data = self._cache[cache_index] - other = self._cache_other[cache_index] + data = self._cache[cache_index] # type: ignore + other = self._cache_other[cache_index] # type: ignore # load data from cache and execute from the first random transform if not isinstance(self.patch_transform, Compose): From 84a50c7d7633929d6f6fe772338d8931db3b1075 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 14:51:47 +0800 Subject: [PATCH 05/22] fix #6585 Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 35 ++++++++++++++++++----------------- tests/test_patch_dataset.py | 27 +++++++++++++++------------ 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 1f45c23583..37202aaa3b 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -364,7 +364,7 @@ def __iter__(self): yield out_patch -class PatchDataset(Dataset): +class PatchDataset(IterableDataset): """ returns a patch from an image dataset. The patches are generated by a user-specified callable `patch_func`, @@ -416,26 +416,27 @@ def __init__( samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements. transform: transform applied to each patch. """ - super().__init__(data=data, transform=transform) + super().__init__(data=data, transform=None) self.patch_func = patch_func if samples_per_image <= 0: raise ValueError("sampler_per_image must be a positive integer.") self.samples_per_image = int(samples_per_image) + self.patch_transform = transform def __len__(self) -> int: - return len(self.data) * self.samples_per_image - - def _transform(self, index: int): - image_id = int(index / self.samples_per_image) - image = self.data[image_id] - patches = self.patch_func(image) - if len(patches) != self.samples_per_image: - raise RuntimeWarning( - f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}." - ) - patch_id = (index - image_id * self.samples_per_image) * (-1 if index < 0 else 1) - patch = patches[patch_id] - if self.transform is not None: - patch = apply_transform(self.transform, patch, map_items=False) - return patch + return len(self.data) * self.samples_per_image #type: ignore + + def __iter__(self): + for image in super().__iter__(): + patches = self.patch_func(image) + if len(patches) != self.samples_per_image: + raise RuntimeWarning( + f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}." 
+ ) + for patch in patches: + out_patch = patch + if self.patch_transform is not None: + out_patch = apply_transform(self.patch_transform, patch, map_items=False) + yield out_patch + diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 7d66bdccbb..612f379989 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -27,18 +27,21 @@ def identity(x): class TestPatchDataset(unittest.TestCase): - def test_shape(self): - test_dataset = ["vwxyz", "hello", "world"] - n_per_image = len(test_dataset[0]) + # def test_shape(self): + # test_dataset = ["vwxyz", "hello", "world"] + # n_per_image = len(test_dataset[0]) - result = PatchDataset(data=test_dataset, patch_func=identity, samples_per_image=n_per_image) + # result = PatchDataset(data=test_dataset, patch_func=identity, samples_per_image=n_per_image) - output = [] - n_workers = 0 if sys.platform == "win32" else 2 - for item in DataLoader(result, batch_size=3, num_workers=n_workers): - output.append("".join(item)) - expected = ["vwx", "yzh", "ell", "owo", "rld"] - self.assertEqual(output, expected) + # output = [] + # n_workers = 0 if sys.platform == "win32" else 2 + # for item in DataLoader(result, batch_size=3, num_workers=n_workers): + # output.append("".join(item)) + # if n_workers == 0: + # expected = ["vwx", "yzh", "ell", "owo", "rld"] + # else: + # expected = ["vwx", "hel", "yzw", "lo", "orl", "d"] + # self.assertEqual(output, expected) def test_loading_array(self): set_determinism(seed=1234) @@ -61,7 +64,7 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[-0.593095, 0.406905, 1.406905], [3.406905, 4.406905, 5.406905], [7.406905, 8.406905, 9.406905]]] + [[[4.970372, 5.970372, 6.970372], [8.970372, 9.970372, 10.970372], [12.970372, 13.970372, 14.970372]]] ), rtol=1e-5, ) @@ -71,7 +74,7 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[0.234308, 1.234308, 2.234308], [4.234308, 5.234308, 6.234308], [8.234308, 9.234308, 10.234308]]] + [[[5.028125, 6.028125, 7.028125], [9.028125, 10.028125, 11.028125], [13.028125, 14.028125, 15.028125]]] ), rtol=1e-5, ) From 993ca74353ae3a9be12e16feb237d8cf93c9a7be Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 14:53:11 +0800 Subject: [PATCH 06/22] minor fix Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 37202aaa3b..c17bbc3902 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -24,7 +24,6 @@ from monai.config import KeysCollection from monai.config.type_definitions import NdarrayTensor -from monai.data.dataset import CacheDataset, Dataset from monai.data.iterable_dataset import IterableDataset from monai.data.utils import iter_patch, pickle_hashing from monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous @@ -425,7 +424,7 @@ def __init__( self.patch_transform = transform def __len__(self) -> int: - return len(self.data) * self.samples_per_image #type: ignore + return len(self.data) * self.samples_per_image # type: ignore def __iter__(self): for image in super().__iter__(): @@ -439,4 +438,3 @@ def __iter__(self): if self.patch_transform is not None: out_patch = apply_transform(self.patch_transform, patch, map_items=False) yield out_patch - From 66fc0b4acf43611a4b25aa6127fce8d3e8284351 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 14:53:36 +0800 Subject: [PATCH 
07/22] fix flake8 Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 2 +- tests/test_patch_dataset.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index c17bbc3902..f53c3fbafa 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -13,7 +13,7 @@ import sys import warnings -from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence, Iterator +from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence from copy import deepcopy from multiprocessing.managers import ListProxy from multiprocessing.pool import ThreadPool diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 612f379989..266252c090 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -74,7 +74,13 @@ def test_loading_array(self): np.testing.assert_allclose( item[0], np.array( - [[[5.028125, 6.028125, 7.028125], [9.028125, 10.028125, 11.028125], [13.028125, 14.028125, 15.028125]]] + [ + [ + [5.028125, 6.028125, 7.028125], + [9.028125, 10.028125, 11.028125], + [13.028125, 14.028125, 15.028125], + ] + ] ), rtol=1e-5, ) From cf40cbc833aa076706d9ecacae3eb4e20736f66a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 2 Nov 2023 14:56:45 +0800 Subject: [PATCH 08/22] minor fix Signed-off-by: KumoLiu --- tests/test_patch_dataset.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 266252c090..eb705f0c61 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -27,21 +27,21 @@ def identity(x): class TestPatchDataset(unittest.TestCase): - # def test_shape(self): - # test_dataset = ["vwxyz", "hello", "world"] - # n_per_image = len(test_dataset[0]) + def test_shape(self): + test_dataset = ["vwxyz", "hello", "world"] + n_per_image = len(test_dataset[0]) - # result = PatchDataset(data=test_dataset, patch_func=identity, samples_per_image=n_per_image) + result = PatchDataset(data=test_dataset, patch_func=identity, samples_per_image=n_per_image) - # output = [] - # n_workers = 0 if sys.platform == "win32" else 2 - # for item in DataLoader(result, batch_size=3, num_workers=n_workers): - # output.append("".join(item)) - # if n_workers == 0: - # expected = ["vwx", "yzh", "ell", "owo", "rld"] - # else: - # expected = ["vwx", "hel", "yzw", "lo", "orl", "d"] - # self.assertEqual(output, expected) + output = [] + n_workers = 0 if sys.platform == "win32" else 2 + for item in DataLoader(result, batch_size=3, num_workers=n_workers): + output.append("".join(item)) + if n_workers == 0: + expected = ["vwx", "yzh", "ell", "owo", "rld"] + else: + expected = ["vwx", "hel", "yzw", "lo", "orl", "d"] + self.assertEqual(output, expected) def test_loading_array(self): set_determinism(seed=1234) From 7b79b341dcac6dec196322df8b704fb234db2acf Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 3 Nov 2023 13:53:49 +0800 Subject: [PATCH 09/22] Update monai/data/grid_dataset.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/grid_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index f53c3fbafa..6884d7dec9 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -341,7 +341,7 @@ def 
__iter__(self): lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) if first_random is not None: - data = deepcopy(data) if self.copy_cache is True else data + data = deepcopy(data) if self.copy_cache else data for out_patch, others in zip(data, other): if self.patch_transform is not None: out_patch = self.patch_transform(out_patch, start=first_random) From f3b3b98ce6286adf9675804a32ff005e11ecfbec Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 3 Nov 2023 20:49:04 +0800 Subject: [PATCH 10/22] address comments Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 41 +++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 6884d7dec9..6d7889f985 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -237,6 +237,11 @@ def __init__( self._cache: list | ListProxy = [] self._cache_other: list | ListProxy = [] self.cache = cache + if self.patch_transform is not None: + self.first_random = self.patch_transform.get_index_of_first( + lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) + ) + if self.cache: if isinstance(data, Iterator): raise TypeError("Data can not be iterator when cache is True") @@ -293,11 +298,8 @@ def _load_cache_item(self, idx: int): item = self.data[idx] # type: ignore patch_cache, other_cache = [], [] for patch, *others in self.patch_iter(item): - if self.patch_transform is not None: - first_random = self.patch_transform.get_index_of_first( - lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) - ) - patch = self.patch_transform(patch, end=first_random, threading=True) + if self.first_random is not None: + patch = self.patch_transform(patch, end=self.first_random, threading=True) if self.as_contiguous: patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format) @@ -334,29 +336,22 @@ def __iter__(self): other = self._cache_other[cache_index] # type: ignore # load data from cache and execute from the first random transform - if not isinstance(self.patch_transform, Compose): - raise ValueError("transform must be an instance of monai.transforms.Compose.") - - first_random = self.patch_transform.get_index_of_first( - lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) - ) - if first_random is not None: - data = deepcopy(data) if self.copy_cache else data - for out_patch, others in zip(data, other): - if self.patch_transform is not None: - out_patch = self.patch_transform(out_patch, start=first_random) - if ( - self.with_coordinates and len(others) > 0 - ): # patch_iter to yield at least 2 items: patch, coords - yield out_patch, others - else: - yield out_patch + data = deepcopy(data) if self.copy_cache else data + for out_patch, others in zip(data, other): + if self.first_random is not None: + out_patch = self.patch_transform(out_patch, start=self.first_random) + if ( + self.with_coordinates and len(others) > 0 + ): # patch_iter to yield at least 2 items: patch, coords + yield out_patch, others + else: + yield out_patch else: for image in super().__iter__(): for patch, *others in self.patch_iter(image): out_patch = patch if self.patch_transform is not None: - out_patch = apply_transform(self.patch_transform, patch, map_items=False) + out_patch = self.patch_transform(patch) if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords yield out_patch, others[0] else: From 
caec7002ccd4e00b7e8cbb7d7d62cf260db5a715 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 12:50:44 +0000 Subject: [PATCH 11/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/grid_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 6d7889f985..89956b39ae 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -241,7 +241,7 @@ def __init__( self.first_random = self.patch_transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) ) - + if self.cache: if isinstance(data, Iterator): raise TypeError("Data can not be iterator when cache is True") From e15aeab48fd8b656d2bc268d7fe1467383e707be Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 8 Nov 2023 16:52:10 +0800 Subject: [PATCH 12/22] add `_generate_patches` Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 42 +++++++++++++------------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 89956b39ae..52c5b8922e 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -309,6 +309,16 @@ def _load_cache_item(self, idx: int): self._cache_other.append(other_cache) return patch_cache + def _generate_patches(self, src, **apply_args): + for patch, *others in src: + out_patch = patch + if self.patch_transform is not None: + out_patch = self.patch_transform(patch, **apply_args) + if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords + yield out_patch, others[0] + else: + yield out_patch + def __iter__(self): if self.cache: cache_index = None @@ -319,43 +329,19 @@ def __iter__(self): cache_index = self._hash_keys.index(key) if cache_index is None: # no cache for this index, execute all the transforms directly - for patch, *others in self.patch_iter(image): - out_patch = patch - if self.patch_transform is not None: - out_patch = apply_transform(self.patch_transform, patch, map_items=False) - if ( - self.with_coordinates and len(others) > 0 - ): # patch_iter to yield at least 2 items: patch, coords - yield out_patch, others[0] - else: - yield out_patch + yield from self._generate_patches(self.patch_iter(image)) if self._cache is None: - raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.") + raise RuntimeError("Cache buffer is not initialized, please call `set_data()` before epoch begins.") data = self._cache[cache_index] # type: ignore other = self._cache_other[cache_index] # type: ignore # load data from cache and execute from the first random transform data = deepcopy(data) if self.copy_cache else data - for out_patch, others in zip(data, other): - if self.first_random is not None: - out_patch = self.patch_transform(out_patch, start=self.first_random) - if ( - self.with_coordinates and len(others) > 0 - ): # patch_iter to yield at least 2 items: patch, coords - yield out_patch, others - else: - yield out_patch + yield from self._generate_patches(zip(data, other), start=self.first_random) else: for image in super().__iter__(): - for patch, *others in self.patch_iter(image): - out_patch = patch - if self.patch_transform is not None: - out_patch = self.patch_transform(patch) - if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, 
coords
+                yield out_patch, others[0]
+            else:
+                yield out_patch
+
     def __iter__(self):
         if self.cache:
             cache_index = None
@@ -319,43 +329,19 @@ def __iter__(self):
                     cache_index = self._hash_keys.index(key)
                 if cache_index is None:
                     # no cache for this index, execute all the transforms directly
-                    for patch, *others in self.patch_iter(image):
-                        out_patch = patch
-                        if self.patch_transform is not None:
-                            out_patch = apply_transform(self.patch_transform, patch, map_items=False)
-                        if (
-                            self.with_coordinates and len(others) > 0
-                        ):  # patch_iter to yield at least 2 items: patch, coords
-                            yield out_patch, others[0]
-                        else:
-                            yield out_patch
+                    yield from self._generate_patches(self.patch_iter(image))
 
                 if self._cache is None:
-                    raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.")
+                    raise RuntimeError("Cache buffer is not initialized, please call `set_data()` before epoch begins.")
                 data = self._cache[cache_index]  # type: ignore
                 other = self._cache_other[cache_index]  # type: ignore
 
                 # load data from cache and execute from the first random transform
                 data = deepcopy(data) if self.copy_cache else data
-                for out_patch, others in zip(data, other):
-                    if self.first_random is not None:
-                        out_patch = self.patch_transform(out_patch, start=self.first_random)
-                    if (
-                        self.with_coordinates and len(others) > 0
-                    ):  # patch_iter to yield at least 2 items: patch, coords
-                        yield out_patch, others
-                    else:
-                        yield out_patch
+                yield from self._generate_patches(zip(data, other), start=self.first_random)
         else:
             for image in super().__iter__():
-                for patch, *others in self.patch_iter(image):
-                    out_patch = patch
-                    if self.patch_transform is not None:
-                        out_patch = self.patch_transform(patch)
-                    if self.with_coordinates and len(others) > 0:  # patch_iter to yield at least 2 items: patch, coords
-                        yield out_patch, others[0]
-                    else:
-                        yield out_patch
+                yield from self._generate_patches(self.patch_iter(image))
 
 
 class PatchDataset(IterableDataset):

From acac03d6a223e65572899070339758e387a4a91d Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Wed, 8 Nov 2023 22:13:38 +0800
Subject: [PATCH 13/22] add unittests

Signed-off-by: KumoLiu
---
 tests/test_grid_dataset.py | 41 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 40 insertions(+), 1 deletion(-)

diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py
index 69de2d729a..49f2efe865 100644
--- a/tests/test_grid_dataset.py
+++ b/tests/test_grid_dataset.py
@@ -108,7 +108,6 @@ def test_shape(self):
         self.assertEqual(sorted(output), sorted(expected))
 
     def test_loading_array(self):
-        # set_determinism(seed=1234)
         # test sequence input data with images
         images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]
         # image level
@@ -178,6 +177,46 @@ def test_loading_dict(self):
             item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5
         )
 
+    def test_set_data(self):
+        from monai.transforms import Compose, Lambda, RandLambda
+
+        images = [np.arange(16, dtype=float).reshape(1, 4, 4)]
+
+        transform = Compose(
+            [Lambda(func=lambda x: np.array(x * 10)), RandLambda(func=lambda x: x + 1)], map_items=False
+        )
+        patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
+        dataset = GridPatchDataset(
+            data=images,
+            patch_iter=patch_iter,
+            transform=transform,
+            cache=True,
+            cache_rate=1.0,
+            copy_cache=not sys.platform == "linux",
+        )
+
+        num_workers = 2 if sys.platform == "linux" else 0
+        for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
+            np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
+            np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)
+            np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
+        # simulate another epoch, the cache content should not be modified
+        for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
+            np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
+            np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)
+            np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
+
+        # update the datalist and fill the cache content
+        data_list2 = [np.arange(1, 17, dtype=float).reshape(1, 4, 4)]
+        dataset.set_data(data=data_list2)
+        # rerun with updated cache content
+        for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):
+            np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
+            np.testing.assert_allclose(
+                item[0], np.array([[[[91, 101], [131, 141]]], [[[111, 121], [151, 161]]]]), rtol=1e-4
+            )
+            np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
+
 
 if __name__ == "__main__":
     unittest.main()
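The new `test_set_data` above exercises the full cached workflow: deterministic transforms run once when the cache is filled, random transforms run on every read, and `set_data` rebuilds the cache for a new data list. Outside the test harness, the intended usage looks roughly like the sketch below (a minimal example assuming only the APIs shown in this series; per the `set_data` docstring, `persistent_workers=False` is required so that fresh worker processes see the refreshed cache):

import numpy as np
from monai.data import DataLoader, GridPatchDataset, PatchIter
from monai.transforms import Compose, Lambda, RandLambda

# deterministic transform results are cached once; the random one runs per epoch
transform = Compose([Lambda(func=lambda x: x * 10), RandLambda(func=lambda x: x + 1)], map_items=False)
images = [np.arange(16, dtype=float).reshape(1, 4, 4)]

ds = GridPatchDataset(data=images, patch_iter=PatchIter(patch_size=(2, 2)), transform=transform, cache=True)
# persistent_workers must stay False so refreshed cache content reaches new workers
loader = DataLoader(ds, batch_size=2, num_workers=2, persistent_workers=False)

for epoch in range(2):
    for patch, coords in loader:
        pass  # training step goes here

ds.set_data([np.arange(1, 17, dtype=float).reshape(1, 4, 4)])  # rebuild the cache for the new data list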
From 06effe5534721556c2813545f2042d3483211501 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Fri, 10 Nov 2023 15:46:08 +0800
Subject: [PATCH 14/22] Update monai/data/grid_dataset.py

Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/data/grid_dataset.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py
index 52c5b8922e..f9bd86e439 100644
--- a/monai/data/grid_dataset.py
+++ b/monai/data/grid_dataset.py
@@ -330,15 +330,15 @@ def __iter__(self):
                 if cache_index is None:
                     # no cache for this index, execute all the transforms directly
                     yield from self._generate_patches(self.patch_iter(image))
-
-                if self._cache is None:
-                    raise RuntimeError("Cache buffer is not initialized, please call `set_data()` before epoch begins.")
-                data = self._cache[cache_index]  # type: ignore
-                other = self._cache_other[cache_index]  # type: ignore
-
-                # load data from cache and execute from the first random transform
-                data = deepcopy(data) if self.copy_cache else data
-                yield from self._generate_patches(zip(data, other), start=self.first_random)
+                else:
+                    if self._cache is None:
+                        raise RuntimeError("Cache buffer is not initialized, please call `set_data()` before epoch begins.")
+                    data = self._cache[cache_index]  # type: ignore
+                    other = self._cache_other[cache_index]  # type: ignore
+
+                    # load data from cache and execute from the first random transform
+                    data = deepcopy(data) if self.copy_cache else data
+                    yield from self._generate_patches(zip(data, other), start=self.first_random)
         else:
             for image in super().__iter__():
                 yield from self._generate_patches(self.patch_iter(image))

From 1bb1ae796c8bcc4774c84d9691a5319901de409c Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 10 Nov 2023 16:41:19 +0800
Subject: [PATCH 15/22] remove unused import

Signed-off-by: KumoLiu
---
 monai/data/grid_dataset.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py
index f9bd86e439..4e735bf3d2 100644
--- a/monai/data/grid_dataset.py
+++ b/monai/data/grid_dataset.py
@@ -36,11 +36,6 @@
 else:
     tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")
 
-cp, _ = optional_import("cupy")
-lmdb, _ = optional_import("lmdb")
-pd, _ = optional_import("pandas")
-kvikio_numpy, _ = optional_import("kvikio.numpy")
-
 __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"]
 
@@ -332,7 +327,9 @@ def __iter__(self):
                     yield from self._generate_patches(self.patch_iter(image))
                 else:
                     if self._cache is None:
-                        raise RuntimeError("Cache buffer is not initialized, please call `set_data()` before epoch begins.")
+                        raise RuntimeError(
+                            "Cache buffer is not initialized, please call `set_data()` before epoch begins."
+                        )
                     data = self._cache[cache_index]  # type: ignore
                     other = self._cache_other[cache_index]  # type: ignore
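With patches 14 and 15 the cached iteration logic is settled. The key mechanism is `Compose.get_index_of_first`, which locates the first random (or non-`Transform`) entry: everything before that index is executed once into the cache, and everything from it onward is re-applied per read via `start=first_random`. A minimal sketch of that split, reusing the `Lambda`/`RandLambda` pair from the tests (the concrete values are illustrative only):

from monai.transforms import Compose, Lambda, RandLambda, RandomizableTrait, Transform

xform = Compose([Lambda(func=lambda x: x * 10), RandLambda(func=lambda x: x + 1)], map_items=False)
first_random = xform.get_index_of_first(
    lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
# first_random == 1: the deterministic Lambda output is cacheable,
# while the RandLambda must still run on every read
cached = xform(5, end=first_random, threading=True)  # deterministic part, computed once -> 50
fresh = xform(cached, start=first_random)  # random part, re-executed each epoch -> 51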
From 520e2b3509c770cb0211c4831e799dc51034314d Mon Sep 17 00:00:00 2001
From: KumoLiu
Date: Fri, 10 Nov 2023 18:52:56 +0800
Subject: [PATCH 16/22] update docstring

Signed-off-by: KumoLiu
---
 monai/data/grid_dataset.py | 29 ++++++++++++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py
index 4e735bf3d2..03b757911d 100644
--- a/monai/data/grid_dataset.py
+++ b/monai/data/grid_dataset.py
@@ -196,6 +196,25 @@ class GridPatchDataset(IterableDataset):
             see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`.
         transform: a callable data transform operates on the patches.
         with_coordinates: whether to yield the coordinates of each patch, default to `True`.
+        cache: whether to use cache mechanism, default to `False`.
+            see also: :py:class:`monai.data.CacheDataset`.
+        cache_num: number of items to be cached. Default is `sys.maxsize`.
+            will take the minimum of (cache_num, data_length x cache_rate, data_length).
+        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
+            will take the minimum of (cache_num, data_length x cache_rate, data_length).
+        num_workers: the number of worker threads if computing cache in the initialization.
+            If num_workers is None then the number returned by os.cpu_count() is used.
+            If a value less than 1 is specified, 1 will be used instead.
+        progress: whether to display a progress bar.
+        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
+            default to `True`. If the random transforms don't modify the cached content
+            (for example, randomly crop from the cached image and deepcopy the crop region)
+            or if every cache item is only used once in a `multi-processing` environment,
+            may set `copy_cache=False` for better performance.
+        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
+            it may help improve the performance of following logic.
+        hash_func: a callable to compute hash from data items to be cached.
+            defaults to `monai.data.utils.pickle_hashing`.
 
     """
 
@@ -305,6 +324,14 @@ def _load_cache_item(self, idx: int):
         return patch_cache
 
     def _generate_patches(self, src, **apply_args):
+        """
+        Yield patches optionally post-processed by transform.
+
+        Args:
+            src: an iterable of image patches.
+            apply_args: other args for `self.patch_transform`.
+
+        """
         for patch, *others in src:
             out_patch = patch
             if self.patch_transform is not None:
@@ -343,7 +370,7 @@ def __iter__(self):
 
 class PatchDataset(IterableDataset):
     """
-    returns a patch from an image dataset.
+    Yields patches from data read from an image dataset.
     The patches are generated by a user-specified callable `patch_func`,
     and are optionally post-processed by `transform`.
For example, to generate random patch samples from an image dataset: From 32d44128c2a5486932e643130af296f8ec01062a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 10 Nov 2023 21:37:03 +0800 Subject: [PATCH 17/22] fix mypy Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 03b757911d..2f397018e7 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -313,7 +313,7 @@ def _load_cache_item(self, idx: int): patch_cache, other_cache = [], [] for patch, *others in self.patch_iter(item): if self.first_random is not None: - patch = self.patch_transform(patch, end=self.first_random, threading=True) + patch = self.patch_transform(patch, end=self.first_random, threading=True) # type: ignore if self.as_contiguous: patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format) From 13c25da1439ab1313c8fa5cf8adc3d717b64e740 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:36:40 +0800 Subject: [PATCH 18/22] Update monai/data/grid_dataset.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/grid_dataset.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 2f397018e7..4f169a7cea 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -272,17 +272,12 @@ def set_data(self, data: Sequence) -> None: """ self.data = data - def _compute_cache_num(data_len: int): - self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len) - # only compute cache for the unique items of dataset, and record the last index for duplicated items mapping = {self.hash_func(v): i for i, v in enumerate(self.data)} - _compute_cache_num(len(mapping)) + self.cache_num = min(int(self.set_num), int(len(mapping)* self.set_rate), len(mapping)) self._hash_keys = list(mapping)[: self.cache_num] indices = list(mapping.values())[: self.cache_num] - self._cache = self._fill_cache(indices) - return def _fill_cache(self, indices=None) -> list: """ From b4038f21de13ca2ecffa221ab0dbb0d6ab0c6944 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:47:19 +0800 Subject: [PATCH 19/22] Update monai/data/grid_dataset.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/grid_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 4f169a7cea..0a1db06ceb 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -294,10 +294,10 @@ def _fill_cache(self, indices=None) -> list: indices = list(range(self.cache_num)) if self.progress and not has_tqdm: warnings.warn("tqdm is not installed, will not show the caching progress bar.") + + pfunc = tqdm if self.progress and has_tqdm else (lambda v, **_: v) with ThreadPool(self.num_workers) as p: - if self.progress and has_tqdm: - return list(tqdm(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset")) - return list(p.imap(self._load_cache_item, indices)) + return list(pfunc(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset")) def _load_cache_item(self, idx: 
int): """ From 2b5e5dffb5a526b5412ea0184b39025a875f7738 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Nov 2023 02:47:41 +0000 Subject: [PATCH 20/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/grid_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 0a1db06ceb..83b8999131 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -294,7 +294,7 @@ def _fill_cache(self, indices=None) -> list: indices = list(range(self.cache_num)) if self.progress and not has_tqdm: warnings.warn("tqdm is not installed, will not show the caching progress bar.") - + pfunc = tqdm if self.progress and has_tqdm else (lambda v, **_: v) with ThreadPool(self.num_workers) as p: return list(pfunc(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset")) From f1f636090310d4242139d557a7957b9d0ef4c667 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 17 Nov 2023 15:38:34 +0800 Subject: [PATCH 21/22] address comments Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 7 +++---- tests/test_grid_dataset.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 83b8999131..ad78dbd3fa 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -274,10 +274,10 @@ def set_data(self, data: Sequence) -> None: # only compute cache for the unique items of dataset, and record the last index for duplicated items mapping = {self.hash_func(v): i for i, v in enumerate(self.data)} - self.cache_num = min(int(self.set_num), int(len(mapping)* self.set_rate), len(mapping)) + self.cache_num = min(int(self.set_num), int(len(mapping) * self.set_rate), len(mapping)) self._hash_keys = list(mapping)[: self.cache_num] indices = list(mapping.values())[: self.cache_num] - self._cache = self._fill_cache(indices) + self._cache, self._cache_other = zip(*self._fill_cache(indices)) def _fill_cache(self, indices=None) -> list: """ @@ -315,8 +315,7 @@ def _load_cache_item(self, idx: int): if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords other_cache.append(others[0]) patch_cache.append(patch) - self._cache_other.append(other_cache) - return patch_cache + return patch_cache, other_cache def _generate_patches(self, src, **apply_args): """ diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 49f2efe865..d937a5e266 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -180,7 +180,7 @@ def test_loading_dict(self): def test_set_data(self): from monai.transforms import Compose, Lambda, RandLambda - images = [np.arange(16, dtype=float).reshape(1, 4, 4)] + images = [np.arange(2, 18, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] transform = Compose( [Lambda(func=lambda x: np.array(x * 10)), RandLambda(func=lambda x: x + 1)], map_items=False From d8e1f426e6f87826a862e4d3d31d0d20c861394b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 17 Nov 2023 16:11:06 +0800 Subject: [PATCH 22/22] fix ci Signed-off-by: KumoLiu --- monai/data/grid_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index ad78dbd3fa..9079032e6f 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -251,6 
+251,7 @@ def __init__( self._cache: list | ListProxy = [] self._cache_other: list | ListProxy = [] self.cache = cache + self.first_random: int | None = None if self.patch_transform is not None: self.first_random = self.patch_transform.get_index_of_first( lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform) @@ -277,7 +278,7 @@ def set_data(self, data: Sequence) -> None: self.cache_num = min(int(self.set_num), int(len(mapping) * self.set_rate), len(mapping)) self._hash_keys = list(mapping)[: self.cache_num] indices = list(mapping.values())[: self.cache_num] - self._cache, self._cache_other = zip(*self._fill_cache(indices)) + self._cache, self._cache_other = zip(*self._fill_cache(indices)) # type: ignore def _fill_cache(self, indices=None) -> list: """
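One closing note on the cache keying that the final patches converge on: `set_data` deduplicates items through `hash_func` (default `monai.data.utils.pickle_hashing`) before filling the cache, recording the last index of any duplicated item. A small sketch of that behaviour, mirroring the mapping comprehension in `set_data`:

import numpy as np
from monai.data.utils import pickle_hashing

data = [np.zeros((1, 4, 4)), np.zeros((1, 4, 4)), np.ones((1, 4, 4))]
# same dict comprehension as GridPatchDataset.set_data: identical arrays
# pickle to identical bytes, so duplicates collapse onto one key
mapping = {pickle_hashing(v): i for i, v in enumerate(data)}

print(len(mapping))            # 2: only two unique items are cached
print(list(mapping.values()))  # [1, 2]: duplicates keep their last index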