From 08ea9d1fe702b6be5ead2561ad1151d827847987 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 7 Dec 2021 17:24:52 +0800 Subject: [PATCH 1/8] [DLMED] add dataset generator Signed-off-by: Nic Ma --- monai/data/dataset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index f0416a6bdb..b8e8d8c799 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -97,6 +97,16 @@ def __getitem__(self, index: Union[int, slice, Sequence[int]]): return self._transform(index) +class DatasetGenerator(Dataset): + def __init__(self, transform: Callable) -> None: + self.transform = transform + super().__init__(self.reset(), transform=None) + + def reset(self, transform: Optional[Callable] = None) -> Sequence: + self.data = (self.transform if transform is None else transform)() + return self.data + + class PersistentDataset(Dataset): """ Persistent storage of pre-computed values to efficiently manage larger than memory dictionary format data, From b38ac98e8ea4fd0207069acbcb2ba204ecfcfed2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 8 Dec 2021 19:09:51 +0800 Subject: [PATCH 2/8] [DLMED] add DatasetGenerator Signed-off-by: Nic Ma --- docs/source/data.rst | 6 ++++ monai/data/__init__.py | 1 + monai/data/dataset.py | 43 ++++++++++++++++++++++++--- tests/test_dataset_generator.py | 52 +++++++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 4 deletions(-) create mode 100644 tests/test_dataset_generator.py diff --git a/docs/source/data.rst b/docs/source/data.rst index 0ab64edb7b..54a5980744 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -21,6 +21,12 @@ Generic Interfaces :members: :special-members: __next__ +`DatasetGenerator` +~~~~~~~~~~~~~~~~~~ +.. autoclass:: DatasetGenerator + :members: + :special-members: __next__ + `ShuffleBuffer` ~~~~~~~~~~~~~~~ .. autoclass:: ShuffleBuffer diff --git a/monai/data/__init__.py b/monai/data/__init__.py index e7fa2b3107..030aaf1f3a 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -17,6 +17,7 @@ CacheNTransDataset, CSVDataset, Dataset, + DatasetGenerator, LMDBDataset, NPZDictItemDataset, PersistentDataset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index b8e8d8c799..9046a400c7 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -98,12 +98,47 @@ def __getitem__(self, index: Union[int, slice, Sequence[int]]): class DatasetGenerator(Dataset): - def __init__(self, transform: Callable) -> None: - self.transform = transform + """ + Generator to provide dataset items with specified `func`. + It can be used to load / fetch the basic dataset items, like the list of `image, label` paths. + Or chain together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc. + Usage example:: + + data_list = DatasetGenerator( + func=monai.data.load_decathlon_datalist, + data_list_file_path="path to file", + data_list_key="validation", + base_dir="path to base dir", + ) + # partition dataset for every rank + data_partition = DatasetGenerator( + func=lambda **kwargs: monai.data.partition_dataset(**kwargs)[torch.distributed.get_rank()], + data=data_list, + num_partitions=torch.distributed.get_world_size(), + ) + dataset = Dataset(data=data_partition, transform=transforms) + + Args: + func: callable function to generate dataset items. + kwargs: arguments for the `func`. + + """ + + def __init__(self, func: Callable, **kwargs) -> None: + self.func = func + self.kwargs = kwargs super().__init__(self.reset(), transform=None) - def reset(self, transform: Optional[Callable] = None) -> Sequence: - self.data = (self.transform if transform is None else transform)() + def reset(self, func: Optional[Callable] = None, **kwargs) -> Sequence: + """ + Reset the dataset items with specified `func`. + + Args: + func: if not None, execute the `func` with specified `kwargs`, default to `self.func`. + + """ + self.data = self.func(**self.kwargs) if func is None else func(**kwargs) + return self.data diff --git a/tests/test_dataset_generator.py b/tests/test_dataset_generator.py new file mode 100644 index 0000000000..7096289c44 --- /dev/null +++ b/tests/test_dataset_generator.py @@ -0,0 +1,52 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +import unittest + +from monai.data import Dataset, DatasetGenerator, load_decathlon_datalist, partition_dataset + + +class TestDatasetGenerator(unittest.TestCase): + def test_seg_values(self): + with tempfile.TemporaryDirectory() as tempdir: + # prepare test datalist file + test_data = { + "name": "Spleen", + "description": "Spleen Segmentation", + "labels": {"0": "background", "1": "spleen"}, + "training": [ + {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"}, + {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"}, + ], + "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + + data_list = DatasetGenerator( + func=load_decathlon_datalist, data_list_file_path=file_path, data_list_key="training", base_dir=tempdir + ) + # partition dataset for train / validation + data_partition = DatasetGenerator( + func=lambda **kwargs: partition_dataset(**kwargs)[0], data=data_list, num_partitions=2 + ) + dataset = Dataset(data=data_partition, transform=None) + self.assertEqual(dataset[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) + self.assertEqual(dataset[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) + + +if __name__ == "__main__": + unittest.main() From 0fe1fc70a514dd30bdd3564363f2fb8b72682d74 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 9 Dec 2021 00:13:34 +0800 Subject: [PATCH 3/8] [DLMED] update according to comments Signed-off-by: Nic Ma --- docs/source/data.rst | 6 ++--- monai/data/__init__.py | 2 +- monai/data/dataset.py | 26 ++++++++++++------- ...aset_generator.py => test_dataset_func.py} | 12 ++++----- 4 files changed, 26 insertions(+), 20 deletions(-) rename tests/{test_dataset_generator.py => test_dataset_func.py} (80%) diff --git a/docs/source/data.rst b/docs/source/data.rst index 54a5980744..e8c68de853 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -21,9 +21,9 @@ Generic Interfaces :members: :special-members: __next__ -`DatasetGenerator` -~~~~~~~~~~~~~~~~~~ -.. autoclass:: DatasetGenerator +`DatasetFunc` +~~~~~~~~~~~~~ +.. autoclass:: DatasetFunc :members: :special-members: __next__ diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 030aaf1f3a..b12a307663 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -17,7 +17,7 @@ CacheNTransDataset, CSVDataset, Dataset, - DatasetGenerator, + DatasetFunc, LMDBDataset, NPZDictItemDataset, PersistentDataset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 1c08aaf2ff..fb11f6ca83 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -97,47 +97,53 @@ def __getitem__(self, index: Union[int, slice, Sequence[int]]): return self._transform(index) -class DatasetGenerator(Dataset): +class DatasetFunc(Dataset): """ - Generator to provide dataset items with specified `func`. + Execute function on the input dataset and leverage the output to act as a new Dataset. It can be used to load / fetch the basic dataset items, like the list of `image, label` paths. Or chain together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc. + The `data` arg of `Dataset` will be applied to the first arg of callable `func`. Usage example:: - data_list = DatasetGenerator( + data_list = DatasetFunc( + data="path to file", func=monai.data.load_decathlon_datalist, - data_list_file_path="path to file", data_list_key="validation", base_dir="path to base dir", ) # partition dataset for every rank - data_partition = DatasetGenerator( - func=lambda **kwargs: monai.data.partition_dataset(**kwargs)[torch.distributed.get_rank()], + data_partition = DatasetFunc( data=data_list, + func=lambda **kwargs: monai.data.partition_dataset(**kwargs)[torch.distributed.get_rank()], num_partitions=torch.distributed.get_world_size(), ) dataset = Dataset(data=data_partition, transform=transforms) Args: + data: input data for the func to process, will apply to `func` as the first arg. func: callable function to generate dataset items. - kwargs: arguments for the `func`. + kwargs: other arguments for the `func` except for the first arg. """ - def __init__(self, func: Callable, **kwargs) -> None: + def __init__(self, data: Any, func: Callable, **kwargs) -> None: + self.src = data self.func = func self.kwargs = kwargs super().__init__(self.reset(), transform=None) - def reset(self, func: Optional[Callable] = None, **kwargs) -> Sequence: + def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs) -> Sequence: """ Reset the dataset items with specified `func`. Args: + data: if not None, execute `func` on it, default to `self.src`. func: if not None, execute the `func` with specified `kwargs`, default to `self.func`. + kwargs: other arguments for the `func` except for the first arg. """ - self.data = self.func(**self.kwargs) if func is None else func(**kwargs) + src = self.src if data is None else data + self.data = self.func(src, **self.kwargs) if func is None else func(src, **kwargs) return self.data diff --git a/tests/test_dataset_generator.py b/tests/test_dataset_func.py similarity index 80% rename from tests/test_dataset_generator.py rename to tests/test_dataset_func.py index 7096289c44..d57f4862cb 100644 --- a/tests/test_dataset_generator.py +++ b/tests/test_dataset_func.py @@ -14,10 +14,10 @@ import tempfile import unittest -from monai.data import Dataset, DatasetGenerator, load_decathlon_datalist, partition_dataset +from monai.data import Dataset, DatasetFunc, load_decathlon_datalist, partition_dataset -class TestDatasetGenerator(unittest.TestCase): +class TestDatasetFunc(unittest.TestCase): def test_seg_values(self): with tempfile.TemporaryDirectory() as tempdir: # prepare test datalist file @@ -36,12 +36,12 @@ def test_seg_values(self): with open(file_path, "w") as json_file: json_file.write(json_str) - data_list = DatasetGenerator( - func=load_decathlon_datalist, data_list_file_path=file_path, data_list_key="training", base_dir=tempdir + data_list = DatasetFunc( + data=file_path, func=load_decathlon_datalist, data_list_key="training", base_dir=tempdir, ) # partition dataset for train / validation - data_partition = DatasetGenerator( - func=lambda **kwargs: partition_dataset(**kwargs)[0], data=data_list, num_partitions=2 + data_partition = DatasetFunc( + data=data_list, func=lambda **kwargs: partition_dataset(**kwargs)[0], num_partitions=2 ) dataset = Dataset(data=data_partition, transform=None) self.assertEqual(dataset[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) From be10f2b4c6bc90f9e092233a095b422e2780a3ba Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 8 Dec 2021 16:19:26 +0000 Subject: [PATCH 4/8] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_dataset_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dataset_func.py b/tests/test_dataset_func.py index d57f4862cb..53f4461e96 100644 --- a/tests/test_dataset_func.py +++ b/tests/test_dataset_func.py @@ -37,7 +37,7 @@ def test_seg_values(self): json_file.write(json_str) data_list = DatasetFunc( - data=file_path, func=load_decathlon_datalist, data_list_key="training", base_dir=tempdir, + data=file_path, func=load_decathlon_datalist, data_list_key="training", base_dir=tempdir ) # partition dataset for train / validation data_partition = DatasetFunc( From c75b880809d317c3c2cb6760aa1cce1ab96fb4d8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 9 Dec 2021 00:24:17 +0800 Subject: [PATCH 5/8] [DLMED] fix wrong test Signed-off-by: Nic Ma --- tests/test_dataset_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dataset_func.py b/tests/test_dataset_func.py index 53f4461e96..b3f6b95403 100644 --- a/tests/test_dataset_func.py +++ b/tests/test_dataset_func.py @@ -41,7 +41,7 @@ def test_seg_values(self): ) # partition dataset for train / validation data_partition = DatasetFunc( - data=data_list, func=lambda **kwargs: partition_dataset(**kwargs)[0], num_partitions=2 + data=data_list, func=lambda x, **kwargs: partition_dataset(x, **kwargs)[0], num_partitions=2 ) dataset = Dataset(data=data_partition, transform=None) self.assertEqual(dataset[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) From b292948f97cc2898ba4829780b605446f4884118 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 10 Dec 2021 10:56:17 +0800 Subject: [PATCH 6/8] [DLMED] simplify according to comments Signed-off-by: Nic Ma --- monai/data/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index fb11f6ca83..c8f4ef3811 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -127,10 +127,11 @@ class DatasetFunc(Dataset): """ def __init__(self, data: Any, func: Callable, **kwargs) -> None: + super().__init__(data=None, transform=None) # type:ignore self.src = data self.func = func self.kwargs = kwargs - super().__init__(self.reset(), transform=None) + self.reset() def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs) -> Sequence: """ From d069a7d98c6968aeba4db5c60d1e43bb8562c239 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 10 Dec 2021 16:40:03 +0800 Subject: [PATCH 7/8] [DLMED] remove return Signed-off-by: Nic Ma --- monai/data/dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index c8f4ef3811..ccd831ee0f 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -133,7 +133,7 @@ def __init__(self, data: Any, func: Callable, **kwargs) -> None: self.kwargs = kwargs self.reset() - def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs) -> Sequence: + def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **kwargs): """ Reset the dataset items with specified `func`. @@ -146,8 +146,6 @@ def reset(self, data: Optional[Any] = None, func: Optional[Callable] = None, **k src = self.src if data is None else data self.data = self.func(src, **self.kwargs) if func is None else func(src, **kwargs) - return self.data - class PersistentDataset(Dataset): """ From 8c127d3bc1154e76649d18aa1b3e73bade059a04 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 10 Dec 2021 17:25:28 +0800 Subject: [PATCH 8/8] [DLMED] update rtol for CI Signed-off-by: Nic Ma --- tests/test_scale_intensity_range_percentilesd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_scale_intensity_range_percentilesd.py b/tests/test_scale_intensity_range_percentilesd.py index ac2118d99f..0fcda21feb 100644 --- a/tests/test_scale_intensity_range_percentilesd.py +++ b/tests/test_scale_intensity_range_percentilesd.py @@ -35,7 +35,7 @@ def test_scaling(self): scaler = ScaleIntensityRangePercentilesd( keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max ) - assert_allclose(p(expected), scaler(data)["img"]) + assert_allclose(p(expected), scaler(data)["img"], rtol=1e-4) def test_relative_scaling(self): img = self.imt