Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance CacheDataset to avoid duplicated cache (fixes issue #3734) — PR #3739

Merged
merged 16 commits into from
Feb 1, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[DLMED] update according to comments
Signed-off-by: Nic Ma <nma@nvidia.com>
Nic-Ma committed Jan 31, 2022
commit 5b4715da9242e563e34e3691267f66133bf4df95
31 changes: 18 additions & 13 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
@@ -707,8 +707,8 @@ def __init__(
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.set_num = cache_num
self.set_rate = cache_rate
self.set_num = cache_num # tracking the user-provided `cache_num` option
self.set_rate = cache_rate # tracking the user-provided `cache_rate` option
self.progress = progress
self.copy_cache = copy_cache
self.as_contiguous = as_contiguous
@@ -718,8 +718,7 @@ def __init__(
if self.num_workers is not None:
self.num_workers = max(int(self.num_workers), 1)
self.cache_num = 0
self._cache_keys: List = []
self._cache: List = []
self._cache: Union[List, Dict] = []
self.set_data(data)

def set_data(self, data: Sequence):
@@ -731,16 +730,21 @@ def set_data(self, data: Sequence):
generated cache content.

"""

def _compute_cache():
self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data))
return self._fill_cache()

if self.hash_as_key:
# only compute cache for the unique items of dataset
mapping = {self.hash_func(v): v for v in data}
self._cache_keys = list(mapping)
self.data = list(mapping.values())
cache_ = _compute_cache()
self._cache = dict(zip(list(mapping)[: self.cache_num], cache_))
self.data = data
else:
self.data = data
self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data))
self._cache = self._fill_cache()
self.data = data
self._cache = _compute_cache()

def _fill_cache(self) -> List:
if self.cache_num <= 0:
@@ -775,20 +779,21 @@ def _load_cache_item(self, idx: int):
return item

def _transform(self, index: int):
index_: Any = index
if self.hash_as_key:
key = self.hash_func(self.data[index])
if key in self._cache_keys[: self.cache_num]:
if key in self._cache:
# if existing in cache, get the index
index = self._cache_keys.index(key)
index_ = key # if using hash as cache keys, set the key

if index % len(self) >= self.cache_num: # support negative index
if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index
# no cache for this index, execute all the transforms directly
return super()._transform(index)
return super()._transform(index_)
# load data from cache and execute from the first random transform
start_run = False
if self._cache is None:
self._cache = self._fill_cache()
data = self._cache[index]
data = self._cache[index_]
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms: