diff --git a/monai/data/__init__.py b/monai/data/__init__.py index fca170335b..edfcc02996 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -58,6 +58,7 @@ pickle_hashing, rectify_header_sform_qform, rep_scalar_to_batch, + resample_datalist, select_cross_validation_folds, set_rnd, sorted_dict, diff --git a/monai/data/utils.py b/monai/data/utils.py index 880ceed7b8..4b29004ad9 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -66,6 +66,7 @@ "is_supported_format", "partition_dataset", "partition_dataset_classes", + "resample_datalist", "select_cross_validation_folds", "json_hashing", "pickle_hashing", @@ -991,6 +992,31 @@ def partition_dataset_classes( return datasets +def resample_datalist(data: Sequence, factor: float, random_pick: bool = False, seed: int = 0): + """ + Utility function to resample the loaded datalist for training, for example: + If factor < 1.0, randomly pick part of the datalist and set to Dataset, useful to quickly test the program. + If factor > 1.0, repeat the datalist to enhance the Dataset. + + Args: + data: original datalist to scale. + factor: scale factor for the datalist, for example, factor=4.5, repeat the datalist 4 times and plus + 50% of the original datalist. + random_pick: whether to randomly pick data if scale factor has decimal part. + seed: random seed to randomly pick data. + + """ + scale, repeats = math.modf(factor) + ret: List = list() + + for _ in range(int(repeats)): + ret.extend(list(deepcopy(data))) + if scale > 1e-6: + ret.extend(partition_dataset(data=data, ratios=[scale, 1 - scale], shuffle=random_pick, seed=seed)[0]) + + return ret + + def select_cross_validation_folds(partitions: Sequence[Iterable], folds: Union[Sequence[int], int]) -> List: """ Select cross validation data based on data partitions and specified fold index. diff --git a/tests/test_inverse.py b/tests/test_inverse.py index d547fe7595..d662441494 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -450,7 +450,7 @@ def test_inverse_inferred_seg(self, extra_transform): batch_size = 10 # num workers = 0 for mac - num_workers = 2 if sys.platform != "darwin" else 0 + num_workers = 2 if sys.platform == "linux" else 0 transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform]) num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform)) diff --git a/tests/test_resample_datalist.py b/tests/test_resample_datalist.py new file mode 100644 index 0000000000..1d92e431cd --- /dev/null +++ b/tests/test_resample_datalist.py @@ -0,0 +1,40 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data import resample_datalist + +TEST_CASE_1 = [ + {"data": [1, 2, 3, 4, 5], "factor": 2.5, "random_pick": True, "seed": 123}, + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 2, 4, 5], +] + +TEST_CASE_2 = [ + {"data": [1, 2, 3, 4, 5], "factor": 2.5, "random_pick": False, "seed": 0}, + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3], +] + +TEST_CASE_3 = [{"data": [1, 2, 3, 4, 5], "factor": 0.6, "random_pick": True, "seed": 123}, [2, 4, 5]] + + +class TestResampleDatalist(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value_shape(self, input_param, expected): + result = resample_datalist(**input_param) + np.testing.assert_allclose(result, expected) + + +if __name__ == "__main__": + unittest.main()