From 2f15c17cdde0188a3d377f5256aa95b7a2fa2bfd Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 29 Jan 2022 15:52:59 +0800 Subject: [PATCH 1/7] [DLMED] add HashCacheDataset Signed-off-by: Nic Ma --- docs/source/data.rst | 6 ++ monai/data/__init__.py | 1 + monai/data/dataset.py | 34 +++++++++ tests/test_hashcachedataset.py | 126 +++++++++++++++++++++++++++++++++ 4 files changed, 167 insertions(+) create mode 100644 tests/test_hashcachedataset.py diff --git a/docs/source/data.rst b/docs/source/data.rst index 1e6b535b12..d6ed833d87 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -99,6 +99,12 @@ Generic Interfaces :members: :special-members: __getitem__ +`HashCacheDataset` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: HashCacheDataset + :members: + :special-members: __getitem__ + Patch-based dataset ------------------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index bd49f40273..e2b808f795 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -18,6 +18,7 @@ CSVDataset, Dataset, DatasetFunc, + HashCacheDataset, LMDBDataset, NPZDictItemDataset, PersistentDataset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 8a42ed5181..f143ac8dd2 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1348,3 +1348,37 @@ def __init__( dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs ) super().__init__(data=data, transform=transform) + + +class HashCacheDataset(CacheDataset): + """ + Extend from `CacheDataset` to support only caching unique items in the datset. + It computes hash value of input data as the key to save cache, if key exists, avoid saving duplicated content. + Can help save memory when the dataset has duplicated items or augmented dataset. + The `cache_num` or `cache_rate` are computed against only on the unique items of the dataset. + + Args: + data: input data to load and transform to generate dataset for model. + hash_func: a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. + kwargs: other arguments of `CacheDataset` except for `data`. + + """ + + def __init__(self, data: Sequence, hash_func: Callable[..., bytes] = pickle_hashing, **kwargs) -> None: + self.hash_func = hash_func + mapping = {self.hash_func(v): v for v in data} + # only compute cache for the unique items of dataset + super().__init__(data=list(mapping.values()), **kwargs) + self._cache_keys = list(mapping)[: self.cache_num] + self.data = data + + def _transform(self, index: int): + key = self.hash_func(self.data[index]) + if key in self._cache_keys: + # if existing in cache, get the index + index = self._cache_keys.index(key) + return super()._transform(index=index) + + def set_data(self, _: Sequence): + raise NotImplementedError("`set_data` at runtime is not supported in `HashCacheDataset`.") diff --git a/tests/test_hashcachedataset.py b/tests/test_hashcachedataset.py new file mode 100644 index 0000000000..aae9a008c8 --- /dev/null +++ b/tests/test_hashcachedataset.py @@ -0,0 +1,126 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from parameterized import parameterized + +from monai.data import DataLoader, HashCacheDataset +from monai.transforms import Compose, LoadImaged, ThreadUnsafe, Transform +from monai.utils.module import pytorch_after + +TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label"])]), (128, 128, 128)] + +TEST_CASE_2 = [None, (128, 128, 128)] + +TEST_DS = [] +for c in (0, 1, 2): + for l in (0, 1, 2): + TEST_DS.append([False, c, 0 if sys.platform in ("darwin", "win32") else l]) + if sys.platform not in ("darwin", "win32"): + # persistent_workers need l > 0 + for l in (1, 2): + TEST_DS.append([True, c, l]) + + +class TestCacheDataset(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_shape(self, transform, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + test_data = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + }, + # duplicated data for augmentation + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + }, + ] + dataset = HashCacheDataset(data=test_data, transform=transform, cache_rate=1.0, num_workers=2) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 2) + data1 = dataset[0] + data2 = dataset[1] + data3 = dataset[-1] + # test slice indices + data4 = dataset[0:-1] + self.assertEqual(len(data4), 2) + + if transform is None: + self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz")) + self.assertEqual(data3["image"], os.path.join(tempdir, "test_image2.nii.gz")) + else: + self.assertTupleEqual(data1["image"].shape, expected_shape) + self.assertTupleEqual(data2["label"].shape, expected_shape) + self.assertTupleEqual(data3["image"].shape, expected_shape) + for d in data4: + self.assertTupleEqual(d["image"].shape, expected_shape) + + +class _StatefulTransform(Transform, ThreadUnsafe): + """ + A transform with an internal state. + The state is changing at each call. 
+ """ + + def __init__(self): + self.property = 1 + + def __call__(self, data): + self.property = self.property + 1 + return data * 100 + self.property + + +class TestDataLoader(unittest.TestCase): + @parameterized.expand(TEST_DS) + def test_thread_safe(self, persistent_workers, cache_workers, loader_workers): + expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002] + _kwg = {"persistent_workers": persistent_workers} if pytorch_after(1, 8) else {} + data_list = list(range(1, 11)) + dataset = HashCacheDataset( + data=data_list, transform=_StatefulTransform(), cache_rate=1.0, num_workers=cache_workers, progress=False + ) + self.assertListEqual(expected, list(dataset)) + loader = DataLoader( + HashCacheDataset( + data=data_list, + transform=_StatefulTransform(), + cache_rate=1.0, + num_workers=cache_workers, + progress=False, + ), + batch_size=1, + num_workers=loader_workers, + **_kwg, + ) + self.assertListEqual(expected, [y.item() for y in loader]) + self.assertListEqual(expected, [y.item() for y in loader]) + + +if __name__ == "__main__": + unittest.main() From 5816dec0f85b632e35b4fc26862c4bd0ec5ce9d4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 29 Jan 2022 16:37:35 +0800 Subject: [PATCH 2/7] [DLMED] add more test Signed-off-by: Nic Ma --- tests/test_hashcachedataset.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/test_hashcachedataset.py b/tests/test_hashcachedataset.py index aae9a008c8..43c6cbd7f2 100644 --- a/tests/test_hashcachedataset.py +++ b/tests/test_hashcachedataset.py @@ -45,6 +45,8 @@ def test_shape(self, transform, expected_shape): nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image3.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label3.nii.gz")) test_data = [ { "image": os.path.join(tempdir, "test_image1.nii.gz"), @@ -59,21 +61,30 @@ def test_shape(self, transform, expected_shape): "image": os.path.join(tempdir, "test_image2.nii.gz"), "label": os.path.join(tempdir, "test_label2.nii.gz"), }, + { + "image": os.path.join(tempdir, "test_image3.nii.gz"), + "label": os.path.join(tempdir, "test_label3.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image3.nii.gz"), + "label": os.path.join(tempdir, "test_label3.nii.gz"), + }, ] - dataset = HashCacheDataset(data=test_data, transform=transform, cache_rate=1.0, num_workers=2) + dataset = HashCacheDataset(data=test_data, transform=transform, cache_num=4, num_workers=2) + self.assertEqual(len(dataset), 5) # ensure no duplicated cache content - self.assertEqual(len(dataset._cache), 2) + self.assertEqual(len(dataset._cache), 3) data1 = dataset[0] data2 = dataset[1] data3 = dataset[-1] # test slice indices data4 = dataset[0:-1] - self.assertEqual(len(data4), 2) + self.assertEqual(len(data4), 4) if transform is None: self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz")) self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(data3["image"], os.path.join(tempdir, "test_image2.nii.gz")) + self.assertEqual(data3["image"], os.path.join(tempdir, "test_image3.nii.gz")) else: self.assertTupleEqual(data1["image"].shape, expected_shape) self.assertTupleEqual(data2["label"].shape, expected_shape) From ffdc36657cb2ca78b6e0dc5f54b001cdda17a950 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 29 Jan 2022 
16:51:09 +0800 Subject: [PATCH 3/7] [DLMED] skip min test Signed-off-by: Nic Ma --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 783ab370c1..6f0e9a0ca7 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -155,6 +155,7 @@ def run_testsuit(): "test_zoom_affine", "test_zoomd", "test_prepare_batch_default_dist", + "test_hashcachedataset" ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From edbe6ac77436fa5844ff2f3e5695e59fec386a20 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 29 Jan 2022 16:55:57 +0800 Subject: [PATCH 4/7] [DLMED] fix flake8 Signed-off-by: Nic Ma --- tests/min_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/min_tests.py b/tests/min_tests.py index 6f0e9a0ca7..8d8dff7f19 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -155,7 +155,7 @@ def run_testsuit(): "test_zoom_affine", "test_zoomd", "test_prepare_batch_default_dist", - "test_hashcachedataset" + "test_hashcachedataset", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From b57a75567b8ed33db3dd99ade53e489a2936c21a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 30 Jan 2022 11:58:52 +0800 Subject: [PATCH 5/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- docs/source/data.rst | 6 -- monai/data/__init__.py | 1 - monai/data/dataset.py | 82 +++++++++----------- tests/min_tests.py | 1 - tests/test_cachedataset.py | 70 ++++++++++++----- tests/test_hashcachedataset.py | 137 --------------------------------- 6 files changed, 85 insertions(+), 212 deletions(-) delete mode 100644 tests/test_hashcachedataset.py diff --git a/docs/source/data.rst b/docs/source/data.rst index d6ed833d87..1e6b535b12 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -99,12 +99,6 @@ Generic Interfaces :members: :special-members: __getitem__ -`HashCacheDataset` -~~~~~~~~~~~~~~~~~~ -.. autoclass:: HashCacheDataset - :members: - :special-members: __getitem__ - Patch-based dataset ------------------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index e2b808f795..bd49f40273 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -18,7 +18,6 @@ CSVDataset, Dataset, DatasetFunc, - HashCacheDataset, LMDBDataset, NPZDictItemDataset, PersistentDataset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index f143ac8dd2..8c66fad241 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -676,6 +676,8 @@ def __init__( progress: bool = True, copy_cache: bool = True, as_contiguous: bool = True, + hash_as_key: bool = False, + hash_func: Callable[..., bytes] = pickle_hashing, ) -> None: """ Args: @@ -695,19 +697,30 @@ def __init__( may set `copy=False` for better performance. as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. it may help improve the performance of following logic. + hash_as_key: whether to compute hash value of input data as the key to save cache, + if key exists, avoid saving duplicated content. it can help save memory when + the dataset has duplicated items or augmented dataset. + hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. 
""" if not isinstance(transform, Compose): transform = Compose(transform) super().__init__(data=data, transform=transform) + self.set_num = cache_num + self.set_rate = cache_rate self.progress = progress self.copy_cache = copy_cache self.as_contiguous = as_contiguous - self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data)) + self.hash_as_key = hash_as_key + self.hash_func = hash_func self.num_workers = num_workers if self.num_workers is not None: self.num_workers = max(int(self.num_workers), 1) - self._cache: List = self._fill_cache() + self.cache_num = 0 + self._cache_keys: List = [] + self._cache: List = [] + self.set_data(data) def set_data(self, data: Sequence): """ @@ -718,8 +731,16 @@ def set_data(self, data: Sequence): generated cache content. """ - self.data = data + if self.hash_as_key: + # only compute cache for the unique items of dataset + mapping = {self.hash_func(v): v for v in data} + self._cache_keys = list(mapping) + self.data = list(mapping.values()) + else: + self.data = data + self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data)) self._cache = self._fill_cache() + self.data = data def _fill_cache(self) -> List: if self.cache_num <= 0: @@ -754,6 +775,12 @@ def _load_cache_item(self, idx: int): return item def _transform(self, index: int): + if self.hash_as_key: + key = self.hash_func(self.data[index]) + if key in self._cache_keys[: self.cache_num]: + # if existing in cache, get the index + index = self._cache_keys.index(key) + if index % len(self) >= self.cache_num: # support negative index # no cache for this index, execute all the transforms directly return super()._transform(index) @@ -862,10 +889,14 @@ def __init__( ) -> None: if shuffle: self.set_random_state(seed=seed) - data = copy(data) - self.randomize(data) self.shuffle = shuffle + self._start_pos: int = 0 + self._update_lock: threading.Lock = threading.Lock() + self._round: int = 1 + self._replace_done: bool = False + self._replace_mgr: Optional[threading.Thread] = None + super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous) if self._cache is None: self._cache = self._fill_cache() @@ -884,13 +915,6 @@ def __init__( self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num) self._replacements: List[Any] = [None for _ in range(self._replace_num)] self._replace_data_idx: List[int] = list(range(self._replace_num)) - - self._start_pos: int = 0 - self._update_lock: threading.Lock = threading.Lock() - self._round: int = 1 - self._replace_done: bool = False - self._replace_mgr: Optional[threading.Thread] = None - self._compute_data_idx() def set_data(self, data: Sequence): @@ -1348,37 +1372,3 @@ def __init__( dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs ) super().__init__(data=data, transform=transform) - - -class HashCacheDataset(CacheDataset): - """ - Extend from `CacheDataset` to support only caching unique items in the datset. - It computes hash value of input data as the key to save cache, if key exists, avoid saving duplicated content. - Can help save memory when the dataset has duplicated items or augmented dataset. - The `cache_num` or `cache_rate` are computed against only on the unique items of the dataset. - - Args: - data: input data to load and transform to generate dataset for model. - hash_func: a callable to compute hash from data items to be cached. 
- defaults to `monai.data.utils.pickle_hashing`. - kwargs: other arguments of `CacheDataset` except for `data`. - - """ - - def __init__(self, data: Sequence, hash_func: Callable[..., bytes] = pickle_hashing, **kwargs) -> None: - self.hash_func = hash_func - mapping = {self.hash_func(v): v for v in data} - # only compute cache for the unique items of dataset - super().__init__(data=list(mapping.values()), **kwargs) - self._cache_keys = list(mapping)[: self.cache_num] - self.data = data - - def _transform(self, index: int): - key = self.hash_func(self.data[index]) - if key in self._cache_keys: - # if existing in cache, get the index - index = self._cache_keys.index(key) - return super()._transform(index=index) - - def set_data(self, _: Sequence): - raise NotImplementedError("`set_data` at runtime is not supported in `HashCacheDataset`.") diff --git a/tests/min_tests.py b/tests/min_tests.py index 8d8dff7f19..783ab370c1 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -155,7 +155,6 @@ def run_testsuit(): "test_zoom_affine", "test_zoomd", "test_prepare_batch_default_dist", - "test_hashcachedataset", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index a742f5889a..7227f53e04 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -42,24 +42,12 @@ class TestCacheDataset(unittest.TestCase): def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - }, - ] + test_data = [] + for i in ["1", "2"]: + for k in ["image", "label", "extra"]: + nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz")) + test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]}) + dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5, as_contiguous=True) data1 = dataset[0] data2 = dataset[1] @@ -68,9 +56,9 @@ def test_shape(self, transform, expected_shape): self.assertEqual(len(data3), 1) if transform is None: - self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(data4["image"], os.path.join(tempdir, "test_image2.nii.gz")) + self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) + self.assertEqual(data4["image"], os.path.join(tempdir, "image2.nii.gz")) else: self.assertTupleEqual(data1["image"].shape, expected_shape) self.assertTupleEqual(data1["label"].shape, expected_shape) @@ -195,6 +183,46 @@ def 
test_thread_safe(self, persistent_workers, cache_workers, loader_workers): self.assertListEqual(expected, [y.item() for y in loader]) self.assertListEqual(expected, [y.item() for y in loader]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_hash_as_key(self, transform, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + test_data = [] + for i in ["1", "2", "2", "3", "3"]: + for k in ["image", "label", "extra"]: + nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz")) + test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]}) + + dataset = CacheDataset(data=test_data, transform=transform, cache_num=4, num_workers=2, hash_as_key=True) + self.assertEqual(len(dataset), 5) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 3) + self.assertEqual(dataset.cache_num, 3) + data1 = dataset[0] + data2 = dataset[1] + data3 = dataset[-1] + # test slice indices + data4 = dataset[0:-1] + self.assertEqual(len(data4), 4) + + if transform is None: + self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) + self.assertEqual(data3["image"], os.path.join(tempdir, "image3.nii.gz")) + else: + self.assertTupleEqual(data1["image"].shape, expected_shape) + self.assertTupleEqual(data2["label"].shape, expected_shape) + self.assertTupleEqual(data3["image"].shape, expected_shape) + for d in data4: + self.assertTupleEqual(d["image"].shape, expected_shape) + + test_data2 = test_data[:3] + dataset.set_data(data=test_data2) + self.assertEqual(len(dataset), 3) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 2) + self.assertEqual(dataset.cache_num, 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_hashcachedataset.py b/tests/test_hashcachedataset.py deleted file mode 100644 index 43c6cbd7f2..0000000000 --- a/tests/test_hashcachedataset.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import sys -import tempfile -import unittest - -import nibabel as nib -import numpy as np -from parameterized import parameterized - -from monai.data import DataLoader, HashCacheDataset -from monai.transforms import Compose, LoadImaged, ThreadUnsafe, Transform -from monai.utils.module import pytorch_after - -TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label"])]), (128, 128, 128)] - -TEST_CASE_2 = [None, (128, 128, 128)] - -TEST_DS = [] -for c in (0, 1, 2): - for l in (0, 1, 2): - TEST_DS.append([False, c, 0 if sys.platform in ("darwin", "win32") else l]) - if sys.platform not in ("darwin", "win32"): - # persistent_workers need l > 0 - for l in (1, 2): - TEST_DS.append([True, c, l]) - - -class TestCacheDataset(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, transform, expected_shape): - test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - with tempfile.TemporaryDirectory() as tempdir: - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image3.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label3.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - }, - # duplicated data for augmentation - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image3.nii.gz"), - "label": os.path.join(tempdir, "test_label3.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image3.nii.gz"), - "label": os.path.join(tempdir, "test_label3.nii.gz"), - }, - ] - dataset = HashCacheDataset(data=test_data, transform=transform, cache_num=4, num_workers=2) - self.assertEqual(len(dataset), 5) - # ensure no duplicated cache content - self.assertEqual(len(dataset._cache), 3) - data1 = dataset[0] - data2 = dataset[1] - data3 = dataset[-1] - # test slice indices - data4 = dataset[0:-1] - self.assertEqual(len(data4), 4) - - if transform is None: - self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(data3["image"], os.path.join(tempdir, "test_image3.nii.gz")) - else: - self.assertTupleEqual(data1["image"].shape, expected_shape) - self.assertTupleEqual(data2["label"].shape, expected_shape) - self.assertTupleEqual(data3["image"].shape, expected_shape) - for d in data4: - self.assertTupleEqual(d["image"].shape, expected_shape) - - -class _StatefulTransform(Transform, ThreadUnsafe): - """ - A transform with an internal state. - The state is changing at each call. 
- """ - - def __init__(self): - self.property = 1 - - def __call__(self, data): - self.property = self.property + 1 - return data * 100 + self.property - - -class TestDataLoader(unittest.TestCase): - @parameterized.expand(TEST_DS) - def test_thread_safe(self, persistent_workers, cache_workers, loader_workers): - expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002] - _kwg = {"persistent_workers": persistent_workers} if pytorch_after(1, 8) else {} - data_list = list(range(1, 11)) - dataset = HashCacheDataset( - data=data_list, transform=_StatefulTransform(), cache_rate=1.0, num_workers=cache_workers, progress=False - ) - self.assertListEqual(expected, list(dataset)) - loader = DataLoader( - HashCacheDataset( - data=data_list, - transform=_StatefulTransform(), - cache_rate=1.0, - num_workers=cache_workers, - progress=False, - ), - batch_size=1, - num_workers=loader_workers, - **_kwg, - ) - self.assertListEqual(expected, [y.item() for y in loader]) - self.assertListEqual(expected, [y.item() for y in loader]) - - -if __name__ == "__main__": - unittest.main() From f9af64d825b40e7a36090f4161a38789a184e4e4 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Sun, 30 Jan 2022 04:07:55 +0000 Subject: [PATCH 6/7] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/config/deviceconfig.py | 8 ++++---- monai/data/dataset.py | 2 +- monai/data/dataset_summary.py | 2 +- monai/handlers/parameter_scheduler.py | 2 +- monai/losses/image_dissimilarity.py | 6 +++--- monai/networks/blocks/selfattention.py | 2 +- monai/networks/blocks/upsample.py | 4 ++-- monai/networks/blocks/warp.py | 2 +- monai/networks/layers/convutils.py | 2 +- monai/networks/nets/dints.py | 10 +++++----- monai/networks/nets/highresnet.py | 2 +- monai/networks/nets/regunet.py | 6 +++--- monai/networks/nets/segresnet.py | 6 +++--- monai/networks/utils.py | 2 +- monai/transforms/intensity/array.py | 4 ++-- monai/transforms/smooth_field/array.py | 2 +- monai/transforms/utils_create_transform_ims.py | 2 +- tests/test_lmdbdataset.py | 6 +++--- tests/test_tile_on_grid.py | 2 +- tests/test_tile_on_grid_dict.py | 2 +- 20 files changed, 37 insertions(+), 37 deletions(-) diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index 91b944bde5..fd7ca572e6 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -161,9 +161,9 @@ def get_system_info() -> OrderedDict: ), ) mem = psutil.virtual_memory() - _dict_append(output, "Total physical memory (GB)", lambda: round(mem.total / 1024 ** 3, 1)) - _dict_append(output, "Available memory (GB)", lambda: round(mem.available / 1024 ** 3, 1)) - _dict_append(output, "Used memory (GB)", lambda: round(mem.used / 1024 ** 3, 1)) + _dict_append(output, "Total physical memory (GB)", lambda: round(mem.total / 1024**3, 1)) + _dict_append(output, "Available memory (GB)", lambda: round(mem.available / 1024**3, 1)) + _dict_append(output, "Used memory (GB)", lambda: round(mem.used / 1024**3, 1)) return output @@ -209,7 +209,7 @@ def get_gpu_info() -> OrderedDict: _dict_append(output, f"GPU {gpu} Is integrated", lambda: bool(gpu_info.is_integrated)) _dict_append(output, f"GPU {gpu} Is multi GPU board", lambda: bool(gpu_info.is_multi_gpu_board)) _dict_append(output, f"GPU {gpu} Multi processor count", lambda: gpu_info.multi_processor_count) - _dict_append(output, f"GPU {gpu} Total memory (GB)", lambda: round(gpu_info.total_memory / 1024 ** 3, 1)) + _dict_append(output, f"GPU {gpu} Total memory (GB)", lambda: round(gpu_info.total_memory / 1024**3, 1)) _dict_append(output, 
f"GPU {gpu} CUDA capability (maj.min)", lambda: f"{gpu_info.major}.{gpu_info.minor}") return output diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 8c66fad241..f9804629cc 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -513,7 +513,7 @@ def __init__( self.db_file = self.cache_dir / f"{db_name}.lmdb" self.lmdb_kwargs = lmdb_kwargs or {} if not self.lmdb_kwargs.get("map_size", 0): - self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size + self.lmdb_kwargs["map_size"] = 1024**4 # default map_size # lmdb is single-writer multi-reader by default # the cache is created without multi-threading self._read_env = None diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py index 2b4df4ebbf..956e038569 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -150,7 +150,7 @@ def calculate_statistics(self, foreground_threshold: int = 0): self.data_max, self.data_min = max(voxel_max), min(voxel_min) self.data_mean = (voxel_sum / voxel_ct).item() - self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean ** 2)).item() + self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean**2)).item() def calculate_percentiles( self, diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index c0e18edcd0..67c51fd351 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -134,7 +134,7 @@ def _exponential(initial_value: float, gamma: float, current_step: int) -> float Returns: float: new parameter value """ - return initial_value * gamma ** current_step + return initial_value * gamma**current_step @staticmethod def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float: diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index b527522cd7..a06f6fb5cd 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -126,7 +126,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != pred.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") - t2, p2, tp = target ** 2, pred ** 2, target * pred + t2, p2, tp = target**2, pred**2, target * pred kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) # sum over kernel t_sum = separable_filtering(target, kernels=[kernel.to(pred)] * self.ndim) @@ -217,7 +217,7 @@ def __init__( self.num_bins = num_bins self.kernel_type = kernel_type if self.kernel_type == "gaussian": - self.preterm = 1 / (2 * sigma ** 2) + self.preterm = 1 / (2 * sigma**2) self.bin_centers = bin_centers[None, None, ...] 
self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) @@ -280,7 +280,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torc weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 elif order == 3: weight = ( - weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix < 1) / 6 + weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6 ) weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 else: diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 4a86cd84bc..cf837c5a6f 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -46,7 +46,7 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.head_dim = hidden_size // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 def forward(self, x): q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index c72d1bc518..fa3929df20 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -219,7 +219,7 @@ def __init__( out_channels = out_channels or in_channels if not out_channels: raise ValueError("in_channels need to be specified.") - conv_out_channels = out_channels * (scale_factor ** self.dimensions) + conv_out_channels = out_channels * (scale_factor**self.dimensions) self.conv_block = Conv[Conv.CONV, self.dimensions]( in_channels=in_channels, out_channels=conv_out_channels, kernel_size=3, stride=1, padding=1, bias=bias ) @@ -247,7 +247,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...). 
""" x = self.conv_block(x) - if x.shape[1] % (self.scale_factor ** self.dimensions) != 0: + if x.shape[1] % (self.scale_factor**self.dimensions) != 0: raise ValueError( f"Number of channels after `conv_block` ({x.shape[1]}) must be evenly " "divisible by scale_factor ** dimensions " diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 9fdaab0a48..5b925258b6 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -150,7 +150,7 @@ def forward(self, dvf): Returns: a dense displacement field """ - ddf: torch.Tensor = dvf / (2 ** self.num_steps) + ddf: torch.Tensor = dvf / (2**self.num_steps) for _ in range(self.num_steps): ddf = ddf + self.warp_layer(image=ddf, ddf=ddf) return ddf diff --git a/monai/networks/layers/convutils.py b/monai/networks/layers/convutils.py index 5efb6e792f..1e9ce954e8 100644 --- a/monai/networks/layers/convutils.py +++ b/monai/networks/layers/convutils.py @@ -115,7 +115,7 @@ def gaussian_1d( out = out.clamp(min=0) elif approx.lower() == "sampled": x = torch.arange(-tail, tail + 1, dtype=torch.float, device=sigma.device) - out = torch.exp(-0.5 / (sigma * sigma) * x ** 2) + out = torch.exp(-0.5 / (sigma * sigma) * x**2) if not normalize: # compute the normalizer out = out / (2.5066282 * sigma) elif approx.lower() == "scalespace": diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py index c024d6e0f1..a4aaf32eed 100644 --- a/monai/networks/nets/dints.py +++ b/monai/networks/nets/dints.py @@ -124,7 +124,7 @@ def __init__(self, in_channel: int, out_channel: int, spatial_dims: int = 3): # s0 is upsampled 2x from s1, representing feature sizes at two resolutions. # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization) # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims) - self.ram_cost = in_channel / out_channel * 2 ** self._spatial_dims + 3 + self.ram_cost = in_channel / out_channel * 2**self._spatial_dims + 3 class MixedOp(nn.Module): @@ -330,7 +330,7 @@ def __init__( # define downsample stems before DiNTS search if use_downsample: self.stem_down[str(res_idx)] = StemTS( - nn.Upsample(scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True), + nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], @@ -373,7 +373,7 @@ def __init__( else: self.stem_down[str(res_idx)] = StemTS( - nn.Upsample(scale_factor=1 / (2 ** res_idx), mode=mode, align_corners=True), + nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True), conv_type( in_channels=in_channels, out_channels=self.filter_nums[res_idx], @@ -789,7 +789,7 @@ def get_ram_cost_usage(self, in_size, full: bool = False): image_size = np.array(in_size[-self._spatial_dims :]) sizes = [] for res_idx in range(self.num_depths): - sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2 ** res_idx)).prod()) + sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod()) sizes = torch.tensor(sizes).to(torch.float32).to(self.device) / (2 ** (int(self.use_downsample))) probs_a, arch_code_prob_a = self.get_prob_a(child=False) cell_prob = F.softmax(self.log_alpha_c, dim=-1) @@ -807,7 +807,7 @@ def get_ram_cost_usage(self, in_size, full: bool = False): * (1 + (ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx]).sum()) * sizes[self.arch_code2out[path_idx]] ) - return usage * 32 / 8 / 1024 ** 2 + return usage * 32 / 8 / 1024**2 def 
get_topology_entropy(self, probs): """ diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index 95c0c758af..891a65e67b 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -168,7 +168,7 @@ def __init__( # residual blocks for (idx, params) in enumerate(layer_params[1:-2]): # res blocks except the 1st and last two conv layers. _in_chns, _out_chns = _out_chns, params["n_features"] - _dilation = 2 ** idx + _dilation = 2**idx for _ in range(params["repeat"]): blocks.append( HighResBlock( diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py index 8524563faa..6776c7ce9e 100644 --- a/monai/networks/nets/regunet.py +++ b/monai/networks/nets/regunet.py @@ -92,7 +92,7 @@ def __init__( raise AssertionError self.encode_kernel_sizes: List[int] = encode_kernel_sizes - self.num_channels = [self.num_channel_initial * (2 ** d) for d in range(self.depth + 1)] + self.num_channels = [self.num_channel_initial * (2**d) for d in range(self.depth + 1)] self.min_extract_level = min(self.extract_levels) # init layers @@ -310,14 +310,14 @@ def __init__( encode_kernel_sizes: Union[int, List[int]] = 3, ): for size in image_size: - if size % (2 ** depth) != 0: + if size % (2**depth) != 0: raise ValueError( f"given depth {depth}, " f"all input spatial dimension must be divisible by {2 ** depth}, " f"got input of size {image_size}" ) self.image_size = image_size - self.decode_size = [size // (2 ** depth) for size in image_size] + self.decode_size = [size // (2**depth) for size in image_size] super().__init__( spatial_dims=spatial_dims, in_channels=in_channels, diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py index d2c45dd3a3..299f1ca811 100644 --- a/monai/networks/nets/segresnet.py +++ b/monai/networks/nets/segresnet.py @@ -102,7 +102,7 @@ def _make_down_layers(self): down_layers = nn.ModuleList() blocks_down, spatial_dims, filters, norm = (self.blocks_down, self.spatial_dims, self.init_filters, self.norm) for i in range(len(blocks_down)): - layer_in_channels = filters * 2 ** i + layer_in_channels = filters * 2**i pre_conv = ( get_conv_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2) if i > 0 @@ -299,12 +299,12 @@ def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor): if self.vae_estimate_std: z_sigma = self.vae_fc2(x_vae) z_sigma = F.softplus(z_sigma) - vae_reg_loss = 0.5 * torch.mean(z_mean ** 2 + z_sigma ** 2 - torch.log(1e-8 + z_sigma ** 2) - 1) + vae_reg_loss = 0.5 * torch.mean(z_mean**2 + z_sigma**2 - torch.log(1e-8 + z_sigma**2) - 1) x_vae = z_mean + z_sigma * z_mean_rand else: z_sigma = self.vae_default_std - vae_reg_loss = torch.mean(z_mean ** 2) + vae_reg_loss = torch.mean(z_mean**2) x_vae = z_mean + z_sigma * z_mean_rand diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a4ca0a6fd5..47cc838fdc 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -270,7 +270,7 @@ def pixelshuffle( dim, factor = spatial_dims, scale_factor input_size = list(x.size()) batch_size, channels = input_size[:2] - scale_divisor = factor ** dim + scale_divisor = factor**dim if channels % scale_divisor != 0: raise ValueError( diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 6657950eae..dab2789425 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -182,9 +182,9 @@ def _add_noise(self, img: NdarrayOrTensor, mean: float, std: float): if isinstance(img, 
torch.Tensor): n1 = torch.tensor(self._noise1, device=img.device) n2 = torch.tensor(self._noise2, device=img.device) - return torch.sqrt((img + n1) ** 2 + n2 ** 2) + return torch.sqrt((img + n1) ** 2 + n2**2) - return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) + return np.sqrt((img + self._noise1) ** 2 + self._noise2**2) def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py index 356b0d167f..2a8cf6c7f1 100644 --- a/monai/transforms/smooth_field/array.py +++ b/monai/transforms/smooth_field/array.py @@ -232,7 +232,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen # everything below here is to be computed using the destination type (numpy, tensor, etc.) img = (img - img_min) / (img_rng + 1e-10) # rescale to unit values - img = img ** rfield # contrast is changed by raising image data to a power, in this case the field + img = img**rfield # contrast is changed by raising image data to a power, in this case the field out = (img * img_rng) + img_min # rescale back to the original image value range diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index ab282d5332..977c386121 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -382,7 +382,7 @@ def get_images(data, is_label=False): # we might need to panel the images. this happens if a transform produces e.g. 4 output images. # In this case, we create a 2-by-2 grid from them. Output will be a list containing n_orthog_views, # each element being either the image (if num_samples is 1) or the panelled image. - nrows = int(np.floor(num_samples ** 0.5)) + nrows = int(np.floor(num_samples**0.5)) for view in range(num_orthog_views): result = np.asarray([d[view] for d in data]) nindex, height, width = result.shape diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index b624e5c4e3..33f27ee4bc 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -57,7 +57,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"pickle_protocol": 2, "lmdb_kwargs": {"map_size": 100 * 1024 ** 2}}, + {"pickle_protocol": 2, "lmdb_kwargs": {"map_size": 100 * 1024**2}}, ] TEST_CASE_6 = [ @@ -66,7 +66,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"db_name": "testdb", "lmdb_kwargs": {"map_size": 100 * 1024 ** 2}}, + {"db_name": "testdb", "lmdb_kwargs": {"map_size": 100 * 1024**2}}, ] TEST_CASE_7 = [ @@ -75,7 +75,7 @@ SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), ], (128, 128, 128), - {"db_name": "testdb", "lmdb_kwargs": {"map_size": 2 * 1024 ** 2}}, + {"db_name": "testdb", "lmdb_kwargs": {"map_size": 2 * 1024**2}}, ] diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py index 08fb5d96fe..d2d99654b7 100644 --- a/tests/test_tile_on_grid.py +++ b/tests/test_tile_on_grid.py @@ -107,7 +107,7 @@ def make_image( tiles = np.stack(tiles_list, axis=0) # type: ignore - if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count**2: tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] return imlarge, tiles diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py index 
9f1d67ac29..7491582ce7 100644 --- a/tests/test_tile_on_grid_dict.py +++ b/tests/test_tile_on_grid_dict.py @@ -116,7 +116,7 @@ def make_image( tiles = np.stack(tiles_list, axis=0) # type: ignore - if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count**2: tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] return imlarge, tiles From 5b4715da9242e563e34e3691267f66133bf4df95 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 31 Jan 2022 23:50:31 +0800 Subject: [PATCH 7/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/dataset.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index f9804629cc..426f9856fe 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -707,8 +707,8 @@ def __init__( if not isinstance(transform, Compose): transform = Compose(transform) super().__init__(data=data, transform=transform) - self.set_num = cache_num - self.set_rate = cache_rate + self.set_num = cache_num # tracking the user-provided `cache_num` option + self.set_rate = cache_rate # tracking the user-provided `cache_rate` option self.progress = progress self.copy_cache = copy_cache self.as_contiguous = as_contiguous @@ -718,8 +718,7 @@ def __init__( if self.num_workers is not None: self.num_workers = max(int(self.num_workers), 1) self.cache_num = 0 - self._cache_keys: List = [] - self._cache: List = [] + self._cache: Union[List, Dict] = [] self.set_data(data) def set_data(self, data: Sequence): @@ -731,16 +730,21 @@ def set_data(self, data: Sequence): generated cache content. """ + + def _compute_cache(): + self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data)) + return self._fill_cache() + if self.hash_as_key: # only compute cache for the unique items of dataset mapping = {self.hash_func(v): v for v in data} - self._cache_keys = list(mapping) self.data = list(mapping.values()) + cache_ = _compute_cache() + self._cache = dict(zip(list(mapping)[: self.cache_num], cache_)) + self.data = data else: self.data = data - self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data)) - self._cache = self._fill_cache() - self.data = data + self._cache = _compute_cache() def _fill_cache(self) -> List: if self.cache_num <= 0: @@ -775,20 +779,21 @@ def _load_cache_item(self, idx: int): return item def _transform(self, index: int): + index_: Any = index if self.hash_as_key: key = self.hash_func(self.data[index]) - if key in self._cache_keys[: self.cache_num]: + if key in self._cache: # if existing in cache, get the index - index = self._cache_keys.index(key) + index_ = key # if using hash as cache keys, set the key - if index % len(self) >= self.cache_num: # support negative index + if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index # no cache for this index, execute all the transforms directly - return super()._transform(index) + return super()._transform(index_) # load data from cache and execute from the first random transform start_run = False if self._cache is None: self._cache = self._fill_cache() - data = self._cache[index] + data = self._cache[index_] if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms:
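
Taken together, the series folds the first patch's standalone `HashCacheDataset` into `CacheDataset` itself as the `hash_as_key`/`hash_func` options: items are cached under their `pickle_hashing` key, so duplicated entries share one cache slot and `cache_num` is computed over the unique items only, while indexing and `len()` still follow the original data list. A minimal sketch of the resulting behaviour, assuming the post-series `CacheDataset` signature; the toy integer data is illustrative only:

    from monai.data import CacheDataset

    # duplicated items, e.g. the same sample listed more than once for augmentation
    data = [1, 2, 2, 3, 3]

    # hash_as_key=True caches each unique item once, keyed by pickle_hashing
    ds = CacheDataset(data=data, transform=None, cache_rate=1.0, hash_as_key=True, progress=False)

    print(len(ds))        # 5 -- dataset length still matches the input list
    print(ds.cache_num)   # 3 -- only the unique items are cached
    assert ds[1] == ds[2]  # both indices resolve to the same cached content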