Skip to content

Commit

Permalink
unify progress bar for datasets (Project-MONAI#1625)
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
  • Loading branch information
rijobro authored Feb 23, 2021
1 parent 6bdc8c6 commit db2dbb0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
8 changes: 6 additions & 2 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def __init__(
data = self._generate_data_list(dataset_dir)
if transform == ():
transform = LoadImaged("image")
super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)
CacheDataset.__init__(
self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
)

def randomize(self, data: Optional[Any] = None) -> None:
self.rann = self.R.random()
Expand Down Expand Up @@ -275,7 +277,9 @@ def __init__(
self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys)
if transform == ():
transform = LoadImaged(["image", "label"])
super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)
CacheDataset.__init__(
self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
)

def get_indices(self) -> np.ndarray:
"""
Expand Down
21 changes: 13 additions & 8 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,16 @@ class Dataset(_TorchDataset):
}, }, }]
"""

def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None:
def __init__(self, data: Sequence, transform: Optional[Callable] = None, progress: bool = True) -> None:
"""
Args:
data: input data to load and transform to generate dataset for model.
transform: a callable data transform on input data.
progress: whether to display a progress bar.
"""
self.data = data
self.transform = transform
self.progress = progress

def __len__(self) -> int:
return len(self.data)
Expand Down Expand Up @@ -115,6 +117,7 @@ def __init__(
transform: Union[Sequence[Callable], Callable],
cache_dir: Optional[Union[Path, str]] = None,
hash_func: Callable[..., bytes] = pickle_hashing,
progress: bool = True,
) -> None:
"""
Args:
Expand All @@ -129,10 +132,11 @@ def __init__(
If the cache_dir doesn't exist, will automatically create it.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
progress: whether to display a progress bar.
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
super().__init__(data=data, transform=transform, progress=progress)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
self.hash_func = hash_func
if self.cache_dir is not None:
Expand Down Expand Up @@ -345,7 +349,7 @@ def __init__(
lmdb_kwargs: additional keyword arguments to the lmdb environment.
for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class
"""
super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func)
super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, progress=progress)
if not self.cache_dir:
raise ValueError("cache_dir must be specified.")
self.db_file = self.cache_dir / f"{db_name}.lmdb"
Expand All @@ -354,14 +358,13 @@ def __init__(
if not self.lmdb_kwargs.get("map_size", 0):
self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size
self._read_env = None
self.progress = progress
print(f"Accessing lmdb file: {self.db_file.absolute()}.")

def _fill_cache_start_reader(self):
# create cache
self.lmdb_kwargs["readonly"] = False
env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
if not has_tqdm:
if self.progress and not has_tqdm:
warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.")
for item in tqdm(self.data) if has_tqdm and self.progress else self.data:
key = self.hash_func(item)
Expand Down Expand Up @@ -470,6 +473,7 @@ def __init__(
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: Optional[int] = None,
progress: bool = True,
) -> None:
"""
Args:
Expand All @@ -481,10 +485,11 @@ def __init__(
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker processes to use.
If num_workers is None then the number returned by os.cpu_count() is used.
progress: whether to display a progress bar.
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
super().__init__(data=data, transform=transform, progress=progress)
self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
self.num_workers = num_workers
if self.num_workers is not None:
Expand All @@ -494,10 +499,10 @@ def __init__(
def _fill_cache(self) -> List:
if self.cache_num <= 0:
return []
if not has_tqdm:
if self.progress and not has_tqdm:
warnings.warn("tqdm is not installed, will not show the caching progress bar.")
with ThreadPool(self.num_workers) as p:
if has_tqdm:
if self.progress and has_tqdm:
return list(
tqdm(
p.imap(self._load_cache_item, range(self.cache_num)),
Expand Down

0 comments on commit db2dbb0

Please sign in to comment.