Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3531 Add args to subclass of CacheDataset #3532

Merged
merged 10 commits into from
Dec 22, 2021
52 changes: 48 additions & 4 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class MedNISTDataset(Randomizable, CacheDataset):
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads to use.
if 0 a single thread will be used. Default is 0.
progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
may set `copy_cache=False` for better performance.

Raises:
ValueError: When ``root_dir`` is not a directory.
Expand All @@ -75,6 +81,8 @@ def __init__(
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: int = 0,
progress: bool = True,
copy_cache: bool = True,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
Expand All @@ -87,7 +95,14 @@ def __init__(
dataset_dir = root_dir / self.dataset_folder_name
self.num_class = 0
if download:
download_and_extract(self.resource, tarfile_name, root_dir, self.md5)
download_and_extract(
url=self.resource,
filepath=tarfile_name,
output_dir=root_dir,
hash_val=self.md5,
hash_type="md5",
progress=progress,
)

if not dataset_dir.is_dir():
raise RuntimeError(
Expand All @@ -97,7 +112,14 @@ def __init__(
if transform == ():
transform = LoadImaged("image")
CacheDataset.__init__(
self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
self,
data=data,
transform=transform,
cache_num=cache_num,
cache_rate=cache_rate,
num_workers=num_workers,
progress=progress,
copy_cache=copy_cache,
)

def randomize(self, data: List[int]) -> None:
Expand Down Expand Up @@ -177,6 +199,12 @@ class DecathlonDataset(Randomizable, CacheDataset):
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads to use.
if 0 a single thread will be used. Default is 0.
progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
may set `copy_cache=False` for better performance.

Raises:
ValueError: When ``root_dir`` is not a directory.
Expand Down Expand Up @@ -241,6 +269,8 @@ def __init__(
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: int = 0,
progress: bool = True,
copy_cache: bool = True,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
Expand All @@ -253,7 +283,14 @@ def __init__(
dataset_dir = root_dir / task
tarfile_name = f"{dataset_dir}.tar"
if download:
download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task])
download_and_extract(
url=self.resource[task],
filepath=tarfile_name,
output_dir=root_dir,
hash_val=self.md5[task],
hash_type="md5",
progress=progress,
)

if not dataset_dir.exists():
raise RuntimeError(
Expand All @@ -277,7 +314,14 @@ def __init__(
if transform == ():
transform = LoadImaged(["image", "label"])
CacheDataset.__init__(
self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
self,
data=data,
transform=transform,
cache_num=cache_num,
cache_rate=cache_rate,
num_workers=num_workers,
progress=progress,
copy_cache=copy_cache,
)

def get_indices(self) -> np.ndarray:
Expand Down
6 changes: 6 additions & 0 deletions monai/apps/pathology/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class SmartCachePatchWSIDataset(SmartCacheDataset):
num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch.
If num_replace_workers is None then the number returned by os.cpu_count() is used.
progress: whether to display a progress bar when caching for the first epoch.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cache content
or every cache item is only used once in a `multi-processing` environment,
may set `copy_cache=False` for better performance.

"""

Expand All @@ -139,6 +143,7 @@ def __init__(
num_init_workers: Optional[int] = None,
num_replace_workers: Optional[int] = None,
progress: bool = True,
copy_cache: bool = True,
):
patch_wsi_dataset = PatchWSIDataset(
data=data,
Expand All @@ -157,6 +162,7 @@ def __init__(
num_replace_workers=num_replace_workers,
progress=progress,
shuffle=False,
copy_cache=copy_cache,
)


Expand Down
1 change: 1 addition & 0 deletions tests/test_decathlondataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _test_dataset(dataset):
transform=transform,
section="validation",
download=True,
copy_cache=False,
)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
Expand Down
4 changes: 3 additions & 1 deletion tests/test_mednistdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def _test_dataset(dataset):
self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))

try: # will start downloading if testing_dir doesn't have the MedNIST files
data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=True)
data = MedNISTDataset(
root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False
)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
Expand Down
1 change: 1 addition & 0 deletions tests/test_smartcache_patch_wsi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"cache_num": 2,
"num_init_workers": 1,
"num_replace_workers": 1,
"copy_cache": False,
},
[
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0]]])},
Expand Down