Commit
Summary: Pull Request resolved: #185

`DiskCachedDatasetFromList` was originally defined in `d2go/data/utils.py`, so the class was declared by default whenever that module was imported, and its clean-up call (https://fburl.com/code/cu7hswhx) was therefore always registered, even when the feature was not enabled. This diff moves the class to a new module and delays the import, so the clean-up won't run unless the feature is used.

Reviewed By: tglik

Differential Revision: D34601363

fbshipit-source-id: 734bb9b2c7957d7437ad40c4bfe60a441ec2f23a
1 parent d369931 · commit d3115fa
Showing 4 changed files with 226 additions and 210 deletions.
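The change hinges on the `@atexit.register` decorator on `_clean_up_all` in the file below: registration happens as a side effect of defining the class, so any module that merely imports the file schedules the clean-up at exit. Delaying the import confines that side effect to the code path where disk caching is actually enabled. A minimal sketch of the pattern, assuming the new module is importable as `d2go.data.disk_cache` (the helper name and the enabling flag are illustrative, not part of this diff):

def maybe_wrap_with_disk_cache(lst, disk_cache_enabled):
    # Hypothetical caller in the data-loading path.
    if not disk_cache_enabled:
        return lst
    # Delayed import: the class body (including its atexit registration)
    # only executes when this branch is reached.
    from d2go.data.disk_cache import DiskCachedDatasetFromList

    return DiskCachedDatasetFromList(lst, strategy="batched_static")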
@@ -0,0 +1,221 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import atexit
import logging
import pickle
import shutil
import uuid

import numpy as np
import torch.utils.data as data
from detectron2.utils import comm
from detectron2.utils.logger import log_every_n_seconds

logger = logging.getLogger(__name__)


def _local_master_gather(func, check_equal=False):
    if comm.get_local_rank() == 0:
        x = func()
        assert x is not None
    else:
        x = None
    x_all = comm.all_gather(x)
    x_local_master = [x for x in x_all if x is not None]

    if check_equal:
        master = x_local_master[0]
        assert all(x == master for x in x_local_master), x_local_master

    return x_local_master


class DiskCachedDatasetFromList(data.Dataset):
    """
    Wrap a list into a torch Dataset whose underlying storage is off-loaded to disk
    to save RAM usage.
    """

    CACHE_DIR = "/tmp/DatasetFromList_cache"
    _OCCUPIED_CACHE_DIRS = set()

    def __init__(self, lst, strategy="batched_static"):
        """
        Args:
            lst (list): a list which contains elements to produce.
            strategy (str): strategy of using diskcache, supported strategies:
                - naive: saving each item individually.
                - batched_static: group N items together, where N is calculated from
                    the average item size.
        """
        self._lst = lst
        self._diskcache_strategy = strategy

        def _serialize(data):
            buffer = pickle.dumps(data, protocol=-1)
            return np.frombuffer(buffer, dtype=np.uint8)

        logger.info(
            "Serializing {} elements to byte tensors ...".format(len(self._lst))
        )
        self._lst = [_serialize(x) for x in self._lst]
        # TODO: only enable DiskCachedDatasetFromList for large enough datasets
        logger.info(
            "Serialized dataset takes {:.2f} MiB".format(
                sum(len(x) for x in self._lst) / 1024 ** 2
            )
        )
        self._initialize_diskcache()

    def _initialize_diskcache(self):
        from mobile_cv.common.misc.local_cache import LocalCache

        cache_dir = "{}/{}".format(
            DiskCachedDatasetFromList.CACHE_DIR, uuid.uuid4().hex[:8]
        )
        cache_dir = comm.all_gather(cache_dir)[0]  # use same cache_dir
        logger.info("Creating diskcache database in: {}".format(cache_dir))
        self._cache = LocalCache(cache_dir=cache_dir, num_shards=8)
        # self._cache.cache.clear(retry=True)  # seems faster if index exists

        if comm.get_local_rank() == 0:
            DiskCachedDatasetFromList.get_all_cache_dirs().add(self._cache.cache_dir)

            if self._diskcache_strategy == "naive":
                for i, item in enumerate(self._lst):
                    ret = self._write_to_local_db((i, item))
                    assert ret, "Error writing index {} to local db".format(i)
                    pct = 100.0 * i / len(self._lst)
                    self._log_progress(pct)

            # NOTE: each item might be small in size (hundreds of bytes); writing
            # millions of them can take a pretty long time (hours) because of
            # frequent disk access. One solution is grouping a batch of items
            # into a larger blob.
            elif self._diskcache_strategy == "batched_static":
                TARGET_BYTES = 50 * 1024
                average_bytes = np.average(
                    [
                        self._lst[int(x)].size
                        for x in np.linspace(0, len(self._lst) - 1, 1000)
                    ]
                )
                self._chunk_size = max(1, int(TARGET_BYTES / average_bytes))
                logger.info(
                    "Average data size: {} bytes; target chunk data size {} KiB;"
                    " {} items per chunk; {} chunks in total".format(
                        average_bytes,
                        TARGET_BYTES / 1024,
                        self._chunk_size,
                        int(len(self._lst) / self._chunk_size),
                    )
                )
                for i in range(0, len(self._lst), self._chunk_size):
                    chunk = self._lst[i : i + self._chunk_size]
                    chunk_i = int(i / self._chunk_size)
                    ret = self._write_to_local_db((chunk_i, chunk))
                    assert ret, "Error writing index {} to local db".format(chunk_i)
                    pct = 100.0 * i / len(self._lst)
                    self._log_progress(pct)

            # NOTE: instead of using a fixed chunk size, items can be grouped dynamically
            elif self._diskcache_strategy == "batched_dynamic":
                raise NotImplementedError()

            else:
                raise NotImplementedError(self._diskcache_strategy)

        comm.synchronize()
        logger.info(
            "Finished writing to local disk, db size: {:.2f} MiB".format(
                self._cache.cache.volume() / 1024 ** 2
            )
        )
        # Optional sync for some strategies
        if self._diskcache_strategy == "batched_static":
            # propagate the chunk size and make sure all local masters use the same value
            self._chunk_size = _local_master_gather(
                lambda: self._chunk_size, check_equal=True
            )[0]

        # free the memory of self._lst
        self._size = _local_master_gather(lambda: len(self._lst), check_equal=True)[0]
        del self._lst

    def _write_to_local_db(self, task):
        index, record = task
        db_path = str(index)
        # suc = self._cache.load(lambda path, x: x, db_path, record)
        # record = BytesIO(np.random.bytes(np.random.randint(70000, 90000)))
        suc = self._cache.cache.set(db_path, record, retry=True)
        return suc

    def _log_progress(self, percentage):
        log_every_n_seconds(
            logging.INFO,
            "({:.2f}%) Wrote {} elements to local disk cache, db size: {:.2f} MiB".format(
                percentage,
                len(self._cache.cache),
                self._cache.cache.volume() / 1024 ** 2,
            ),
            n=10,
        )

    def __len__(self):
        if self._diskcache_strategy == "batched_static":
            return self._size
        else:
            raise NotImplementedError()

    def __getitem__(self, idx):
        if self._diskcache_strategy == "naive":
            buf = memoryview(self._cache.cache[str(idx)])
            return pickle.loads(buf)

        elif self._diskcache_strategy == "batched_static":
            chunk_i, residual = divmod(idx, self._chunk_size)
            chunk = self._cache.cache[str(chunk_i)]
            buf = memoryview(chunk[residual])
            return pickle.loads(buf)

        else:
            raise NotImplementedError()

    @classmethod
    def get_all_cache_dirs(cls):
        """return all the occupied cache dirs of DiskCachedDatasetFromList"""
        return DiskCachedDatasetFromList._OCCUPIED_CACHE_DIRS

    def get_cache_dir(self):
        """return the current cache dir of this DiskCachedDatasetFromList instance"""
        return self._cache.cache_dir

    @staticmethod
    def _clean_up_cache_dir(cache_dir, **kwargs):
        print("Cleaning up cache dir: {}".format(cache_dir))
        shutil.rmtree(
            cache_dir,
            onerror=lambda func, path, ex: print(
                "Caught error when removing {}; func: {}; exc_info: {}".format(
                    path, func, ex
                )
            ),
        )

    @staticmethod
    @atexit.register
    def _clean_up_all():
        # in case the program exits unexpectedly, clean up all the cache dirs
        # created by this session.
        if comm.get_local_rank() == 0:
            for cache_dir in DiskCachedDatasetFromList.get_all_cache_dirs():
                DiskCachedDatasetFromList._clean_up_cache_dir(cache_dir)

    def __del__(self):
        # when the data loader gets GC-ed, remove the cache dir. This is needed to
        # avoid wasting disk space in case multiple data loaders are used, e.g. when
        # running evaluations on multiple datasets during training.
        if comm.get_local_rank() == 0:
            DiskCachedDatasetFromList._clean_up_cache_dir(self._cache.cache_dir)
            DiskCachedDatasetFromList.get_all_cache_dirs().remove(self._cache.cache_dir)
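For reference, a minimal usage sketch of the class above (not part of the diff). It assumes `mobile_cv`'s `LocalCache` is installed and `/tmp` is writable; in a single-process run detectron2's `comm` helpers treat the process as local rank 0, so no distributed launcher is required:

# toy records standing in for a serialized detection dataset
records = [{"image_id": i, "annotations": []} for i in range(10000)]

# off-load the pickled records to the local disk cache; "batched_static"
# groups roughly 50 KiB worth of items per diskcache entry
dataset = DiskCachedDatasetFromList(records, strategy="batched_static")

print(len(dataset))             # number of records, propagated from the local master
print(dataset[1234])            # record is read from its on-disk chunk and unpickled
print(dataset.get_cache_dir())  # per-run directory under DiskCachedDatasetFromList.CACHE_DIR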