diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 890d720d0f..90a0f95ced 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -51,6 +51,12 @@ class MedNISTDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. Raises: ValueError: When ``root_dir`` is not a directory. @@ -75,6 +81,8 @@ def __init__( cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, + progress: bool = True, + copy_cache: bool = True, ) -> None: root_dir = Path(root_dir) if not root_dir.is_dir(): @@ -87,7 +95,14 @@ def __init__( dataset_dir = root_dir / self.dataset_folder_name self.num_class = 0 if download: - download_and_extract(self.resource, tarfile_name, root_dir, self.md5) + download_and_extract( + url=self.resource, + filepath=tarfile_name, + output_dir=root_dir, + hash_val=self.md5, + hash_type="md5", + progress=progress, + ) if not dataset_dir.is_dir(): raise RuntimeError( @@ -97,7 +112,14 @@ def __init__( if transform == (): transform = LoadImaged("image") CacheDataset.__init__( - self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + self, + data=data, + transform=transform, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + progress=progress, + copy_cache=copy_cache, ) def randomize(self, data: List[int]) -> None: @@ -177,6 
+199,12 @@ class DecathlonDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. Raises: ValueError: When ``root_dir`` is not a directory. @@ -241,6 +269,8 @@ def __init__( cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, + progress: bool = True, + copy_cache: bool = True, ) -> None: root_dir = Path(root_dir) if not root_dir.is_dir(): @@ -253,7 +283,14 @@ def __init__( dataset_dir = root_dir / task tarfile_name = f"{dataset_dir}.tar" if download: - download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task]) + download_and_extract( + url=self.resource[task], + filepath=tarfile_name, + output_dir=root_dir, + hash_val=self.md5[task], + hash_type="md5", + progress=progress, + ) if not dataset_dir.exists(): raise RuntimeError( @@ -277,7 +314,14 @@ def __init__( if transform == (): transform = LoadImaged(["image", "label"]) CacheDataset.__init__( - self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + self, + data=data, + transform=transform, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + progress=progress, + copy_cache=copy_cache, ) def get_indices(self) -> np.ndarray: diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py index 7d9660349a..bfab7c49da 100644 --- 
a/monai/apps/pathology/data/datasets.py +++ b/monai/apps/pathology/data/datasets.py @@ -122,6 +122,10 @@ class SmartCachePatchWSIDataset(SmartCacheDataset): num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. If num_replace_workers is None then the number returned by os.cpu_count() is used. progress: whether to display a progress bar when caching for the first epoch. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. """ @@ -139,6 +143,7 @@ def __init__( num_init_workers: Optional[int] = None, num_replace_workers: Optional[int] = None, progress: bool = True, + copy_cache: bool = True, ): patch_wsi_dataset = PatchWSIDataset( data=data, @@ -157,6 +162,7 @@ def __init__( num_replace_workers=num_replace_workers, progress=progress, shuffle=False, + copy_cache=copy_cache, ) diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index 45870c8661..9a785668ac 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -47,6 +47,7 @@ def _test_dataset(dataset): transform=transform, section="validation", download=True, + copy_cache=False, ) except (ContentTooShortError, HTTPError, RuntimeError) as e: print(str(e)) diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 51f060ad1b..54fb11135a 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -43,7 +43,9 @@ def _test_dataset(dataset): self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) try: # will start downloading if testing_dir doesn't have the MedNIST files - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=True) + data = MedNISTDataset( + root_dir=testing_dir, transform=transform, 
section="test", download=True, copy_cache=False + ) except (ContentTooShortError, HTTPError, RuntimeError) as e: print(str(e)) if isinstance(e, RuntimeError): diff --git a/tests/test_smartcache_patch_wsi_dataset.py b/tests/test_smartcache_patch_wsi_dataset.py index 98acd4aa03..73583b5eb1 100644 --- a/tests/test_smartcache_patch_wsi_dataset.py +++ b/tests/test_smartcache_patch_wsi_dataset.py @@ -45,6 +45,7 @@ "cache_num": 2, "num_init_workers": 1, "num_replace_workers": 1, + "copy_cache": False, }, [ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0]]])},