diff --git a/d2go/data/disk_cache.py b/d2go/data/disk_cache.py
new file mode 100644
index 00000000..e24605f1
--- /dev/null
+++ b/d2go/data/disk_cache.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+import atexit
+import logging
+import pickle
+import shutil
+import uuid
+
+import numpy as np
+import torch.utils.data as data
+from detectron2.utils import comm
+from detectron2.utils.logger import log_every_n_seconds
+
+logger = logging.getLogger(__name__)
+
+
+def _local_master_gather(func, check_equal=False):
+    if comm.get_local_rank() == 0:
+        x = func()
+        assert x is not None
+    else:
+        x = None
+    x_all = comm.all_gather(x)
+    x_local_master = [x for x in x_all if x is not None]
+
+    if check_equal:
+        master = x_local_master[0]
+        assert all(x == master for x in x_local_master), x_local_master
+
+    return x_local_master
+
+
+class DiskCachedDatasetFromList(data.Dataset):
+    """
+    Wrap a list into a torch Dataset whose underlying storage is off-loaded to
+    disk to save RAM usage.
+    """
+
+    CACHE_DIR = "/tmp/DatasetFromList_cache"
+    _OCCUPIED_CACHE_DIRS = set()
+
+    def __init__(self, lst, strategy="batched_static"):
+        """
+        Args:
+            lst (list): a list which contains elements to produce.
+            strategy (str): strategy of using diskcache, supported strategies:
+                - naive: save each item individually.
+                - batched_static: group N items together, where N is calculated
+                    from the average item size.
+        """
+        self._lst = lst
+        self._diskcache_strategy = strategy
+
+        def _serialize(data):
+            buffer = pickle.dumps(data, protocol=-1)
+            return np.frombuffer(buffer, dtype=np.uint8)
+
+        logger.info(
+            "Serializing {} elements to byte tensors ...".format(len(self._lst))
+        )
+        self._lst = [_serialize(x) for x in self._lst]
+        # TODO: only enable DiskCachedDataset for large enough datasets
+        logger.info(
+            "Serialized dataset takes {:.2f} MiB".format(
+                sum(x.size for x in self._lst) / 1024 ** 2
+            )
+        )
+        self._initialize_diskcache()
+
+    def _initialize_diskcache(self):
+        from mobile_cv.common.misc.local_cache import LocalCache
+
+        cache_dir = "{}/{}".format(
+            DiskCachedDatasetFromList.CACHE_DIR, uuid.uuid4().hex[:8]
+        )
+        cache_dir = comm.all_gather(cache_dir)[0]  # use the same cache_dir across ranks
+        logger.info("Creating diskcache database in: {}".format(cache_dir))
+        self._cache = LocalCache(cache_dir=cache_dir, num_shards=8)
+        # self._cache.cache.clear(retry=True)  # seems faster if index exists
+
+        if comm.get_local_rank() == 0:
+            DiskCachedDatasetFromList.get_all_cache_dirs().add(self._cache.cache_dir)
+
+        if self._diskcache_strategy == "naive":
+            for i, item in enumerate(self._lst):
+                ret = self._write_to_local_db((i, item))
+                assert ret, "Error writing index {} to local db".format(i)
+                pct = 100.0 * i / len(self._lst)
+                self._log_progress(pct)
+
+            # NOTE: each item might be small (hundreds of bytes); writing
+            # millions of them can take a pretty long time (hours) because of
+            # frequent disk access. One solution is grouping a batch of items
+            # into a larger blob.
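+            # For illustration (hypothetical numbers): with ~500-byte serialized
+            # items and the 50 KiB target blob used by "batched_static" below,
+            # each chunk would hold roughly 100 items, turning millions of tiny
+            # writes into tens of thousands of larger ones.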
+        elif self._diskcache_strategy == "batched_static":
+            TARGET_BYTES = 50 * 1024
+            average_bytes = np.average(
+                [
+                    self._lst[int(x)].size
+                    for x in np.linspace(0, len(self._lst) - 1, 1000)
+                ]
+            )
+            self._chunk_size = max(1, int(TARGET_BYTES / average_bytes))
+            logger.info(
+                "Average data size: {} bytes; target chunk data size {} KiB;"
+                " {} items per chunk; {} chunks in total".format(
+                    average_bytes,
+                    TARGET_BYTES / 1024,
+                    self._chunk_size,
+                    int(len(self._lst) / self._chunk_size),
+                )
+            )
+            for i in range(0, len(self._lst), self._chunk_size):
+                chunk = self._lst[i : i + self._chunk_size]
+                chunk_i = int(i / self._chunk_size)
+                ret = self._write_to_local_db((chunk_i, chunk))
+                assert ret, "Error writing index {} to local db".format(chunk_i)
+                pct = 100.0 * i / len(self._lst)
+                self._log_progress(pct)
+
+        # NOTE: instead of using a fixed chunk size, items can be grouped dynamically
+        elif self._diskcache_strategy == "batched_dynamic":
+            raise NotImplementedError()
+
+        else:
+            raise NotImplementedError(self._diskcache_strategy)
+
+        comm.synchronize()
+        logger.info(
+            "Finished writing to local disk, db size: {:.2f} MiB".format(
+                self._cache.cache.volume() / 1024 ** 2
+            )
+        )
+        # Optional sync for some strategies
+        if self._diskcache_strategy == "batched_static":
+            # propagate the chunk size and make sure all local masters use the same value
+            self._chunk_size = _local_master_gather(
+                lambda: self._chunk_size, check_equal=True
+            )[0]
+
+        # free the memory of self._lst
+        self._size = _local_master_gather(lambda: len(self._lst), check_equal=True)[0]
+        del self._lst
+
+    def _write_to_local_db(self, task):
+        index, record = task
+        db_path = str(index)
+        # suc = self._cache.load(lambda path, x: x, db_path, record)
+        # record = BytesIO(np.random.bytes(np.random.randint(70000, 90000)))
+        suc = self._cache.cache.set(db_path, record, retry=True)
+        return suc
+
+    def _log_progress(self, percentage):
+        log_every_n_seconds(
+            logging.INFO,
+            "({:.2f}%) Wrote {} entries to local disk cache, db size: {:.2f} MiB".format(
+                percentage,
+                len(self._cache.cache),
+                self._cache.cache.volume() / 1024 ** 2,
+            ),
+            n=10,
+        )
+
+    def __len__(self):
+        if self._diskcache_strategy == "batched_static":
+            return self._size
+        else:
+            raise NotImplementedError()
+
+    def __getitem__(self, idx):
+        if self._diskcache_strategy == "naive":
+            buf = memoryview(self._cache.cache[str(idx)])
+            return pickle.loads(buf)
+
+        elif self._diskcache_strategy == "batched_static":
+            chunk_i, residual = divmod(idx, self._chunk_size)
+            chunk = self._cache.cache[str(chunk_i)]
+            buf = memoryview(chunk[residual])
+            return pickle.loads(buf)
+
+        else:
+            raise NotImplementedError()
+
+    @classmethod
+    def get_all_cache_dirs(cls):
+        """return all the occupied cache dirs of DiskCachedDatasetFromList"""
+        return DiskCachedDatasetFromList._OCCUPIED_CACHE_DIRS
+
+    def get_cache_dir(self):
+        """return the current cache dir of this DiskCachedDatasetFromList instance"""
+        return self._cache.cache_dir
+
+    @staticmethod
+    def _clean_up_cache_dir(cache_dir, **kwargs):
+        print("Cleaning up cache dir: {}".format(cache_dir))
+        shutil.rmtree(
+            cache_dir,
+            onerror=lambda func, path, ex: print(
+                "Caught error while removing {}; func: {}; exc_info: {}".format(
+                    path, func, ex
+                )
+            ),
+        )
+
+    @staticmethod
+    @atexit.register
+    def _clean_up_all():
+        # in case the program exits unexpectedly, clean all the cache dirs
+        # created by this session.
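+        # Only local rank 0 removes the dirs: all ranks share the rank-0
+        # cache_dir path (see the all_gather in _initialize_diskcache), so one
+        # process per host suffices and concurrent deletions are avoided.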
+        if comm.get_local_rank() == 0:
+            for cache_dir in DiskCachedDatasetFromList.get_all_cache_dirs():
+                DiskCachedDatasetFromList._clean_up_cache_dir(cache_dir)
+
+    def __del__(self):
+        # when the data loader gets GC-ed, remove the cache dir. This is needed
+        # to avoid wasting disk space when multiple data loaders are used, e.g.
+        # when running evaluations on multiple datasets during training.
+        if comm.get_local_rank() == 0:
+            DiskCachedDatasetFromList._clean_up_cache_dir(self._cache.cache_dir)
+            DiskCachedDatasetFromList.get_all_cache_dirs().remove(self._cache.cache_dir)
diff --git a/d2go/data/utils.py b/d2go/data/utils.py
index da2f65e2..22782c93 100644
--- a/d2go/data/utils.py
+++ b/d2go/data/utils.py
@@ -7,11 +7,9 @@
 import json
 import logging
 import os
-import pickle
 import re
 import shutil
 import tempfile
-import uuid
 from collections import defaultdict
 from unittest import mock
 
@@ -30,7 +28,6 @@
 )
 from detectron2.utils import comm
 from detectron2.utils.file_io import PathManager
-from detectron2.utils.logger import log_every_n_seconds
 
 logger = logging.getLogger(__name__)
 
@@ -387,212 +384,6 @@ def update_cfg_if_using_adhoc_dataset(cfg):
     return cfg
 
 
-def _local_master_gather(func, check_equal=False):
-    if comm.get_local_rank() == 0:
-        x = func()
-        assert x is not None
-    else:
-        x = None
-    x_all = comm.all_gather(x)
-    x_local_master = [x for x in x_all if x is not None]
-
-    if check_equal:
-        master = x_local_master[0]
-        assert all(x == master for x in x_local_master), x_local_master
-
-    return x_local_master
-
-
-class DiskCachedDatasetFromList(data.Dataset):
-    """
-    Wrap a list to a torch Dataset, the underlying storage is off-loaded to disk to
-    save RAM usage.
-    """
-
-    CACHE_DIR = "/tmp/DatasetFromList_cache"
-    _OCCUPIED_CACHE_DIRS = set()
-
-    def __init__(self, lst, strategy="batched_static"):
-        """
-        Args:
-            lst (list): a list which contains elements to produce.
-            strategy (str): strategy of using diskcache, supported strategies:
-                - native: saving each item individually.
-                - batched_static: group N items together, where N is calculated from
-                    the average item size.
- """ - self._lst = lst - self._diskcache_strategy = strategy - - def _serialize(data): - buffer = pickle.dumps(data, protocol=-1) - return np.frombuffer(buffer, dtype=np.uint8) - - logger.info( - "Serializing {} elements to byte tensors and concatenating them all ...".format( - len(self._lst) - ) - ) - self._lst = [_serialize(x) for x in self._lst] - # TODO: only enabling DiskCachedDataset for large enough dataset - logger.info( - "Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024 ** 2) - ) - self._initialize_diskcache() - - def _initialize_diskcache(self): - from mobile_cv.common.misc.local_cache import LocalCache - - cache_dir = "{}/{}".format( - DiskCachedDatasetFromList.CACHE_DIR, uuid.uuid4().hex[:8] - ) - cache_dir = comm.all_gather(cache_dir)[0] # use same cache_dir - logger.info("Creating diskcache database in: {}".format(cache_dir)) - self._cache = LocalCache(cache_dir=cache_dir, num_shards=8) - # self._cache.cache.clear(retry=True) # seems faster if index exists - - if comm.get_local_rank() == 0: - DiskCachedDatasetFromList.get_all_cache_dirs().add(self._cache.cache_dir) - - if self._diskcache_strategy == "naive": - for i, item in enumerate(self._lst): - ret = self._write_to_local_db((i, item)) - assert ret, "Error writing index {} to local db".format(i) - pct = 100.0 * i / len(self._lst) - self._log_progress(pct) - - # NOTE: each item might be small in size (hundreds of bytes), - # writing million of them can take a pretty long time (hours) - # because of frequent disk access. One solution is grouping a batch - # of items into larger blob. - elif self._diskcache_strategy == "batched_static": - TARGET_BYTES = 50 * 1024 - average_bytes = np.average( - [ - self._lst[int(x)].size - for x in np.linspace(0, len(self._lst) - 1, 1000) - ] - ) - self._chuck_size = max(1, int(TARGET_BYTES / average_bytes)) - logger.info( - "Average data size: {} bytes; target chuck data size {} KiB;" - " {} items per chuck; {} chucks in total".format( - average_bytes, - TARGET_BYTES / 1024, - self._chuck_size, - int(len(self._lst) / self._chuck_size), - ) - ) - for i in range(0, len(self._lst), self._chuck_size): - chunk = self._lst[i : i + self._chuck_size] - chunk_i = int(i / self._chuck_size) - ret = self._write_to_local_db((chunk_i, chunk)) - assert ret, "Error writing index {} to local db".format(chunk_i) - pct = 100.0 * i / len(self._lst) - self._log_progress(pct) - - # NOTE: instead of using fixed chuck size, items can be grouped dynamically - elif self._diskcache_strategy == "batched_dynamic": - raise NotImplementedError() - - else: - raise NotImplementedError(self._diskcache_strategy) - - comm.synchronize() - logger.info( - "Finished writing to local disk, db size: {:.2f} MiB".format( - self._cache.cache.volume() / 1024 ** 2 - ) - ) - # Optional sync for some strategies - if self._diskcache_strategy == "batched_static": - # propagate chuck size and make sure all local rank 0 uses the same value - self._chuck_size = _local_master_gather( - lambda: self._chuck_size, check_equal=True - )[0] - - # free the memory of self._lst - self._size = _local_master_gather(lambda: len(self._lst), check_equal=True)[0] - del self._lst - - def _write_to_local_db(self, task): - index, record = task - db_path = str(index) - # suc = self._cache.load(lambda path, x: x, db_path, record) - # record = BytesIO(np.random.bytes(np.random.randint(70000, 90000))) - suc = self._cache.cache.set(db_path, record, retry=True) - return suc - - def _log_progress(self, percentage): - log_every_n_seconds( - 
-            logging.INFO,
-            "({:.2f}%) Wrote {} elements to local disk cache, db size: {:.2f} MiB".format(
-                percentage,
-                len(self._cache.cache),
-                self._cache.cache.volume() / 1024 ** 2,
-            ),
-            n=10,
-        )
-
-    def __len__(self):
-        if self._diskcache_strategy == "batched_static":
-            return self._size
-        else:
-            raise NotImplementedError()
-
-    def __getitem__(self, idx):
-        if self._diskcache_strategy == "naive":
-            bytes = memoryview(self._cache.cache[str(idx)])
-            return pickle.loads(bytes)
-
-        elif self._diskcache_strategy == "batched_static":
-            chunk_i, residual = divmod(idx, self._chuck_size)
-            chunk = self._cache.cache[str(chunk_i)]
-            bytes = memoryview(chunk[residual])
-            return pickle.loads(bytes)
-
-        else:
-            raise NotImplementedError()
-
-    @classmethod
-    def get_all_cache_dirs(cls):
-        """return all the ocupied cache dirs of DiskCachedDatasetFromList"""
-        return DiskCachedDatasetFromList._OCCUPIED_CACHE_DIRS
-
-    def get_cache_dir(self):
-        """return the current cache dirs of DiskCachedDatasetFromList instance"""
-        return self._cache.cache_dir
-
-    @staticmethod
-    def _clean_up_cache_dir(cache_dir, **kwargs):
-        print("Cleaning up cache dir: {}".format(cache_dir))
-        shutil.rmtree(
-            cache_dir,
-            onerror=lambda func, path, ex: print(
-                "Catch error when removing {}; func: {}; exc_info: {}".format(
-                    path, func, ex
-                )
-            ),
-        )
-
-    @staticmethod
-    @atexit.register
-    def _clean_up_all():
-        # in case the program exists unexpectly, clean all the cache dirs created by
-        # this session.
-        if comm.get_local_rank() == 0:
-            for cache_dir in DiskCachedDatasetFromList.get_all_cache_dirs():
-                DiskCachedDatasetFromList._clean_up_cache_dir(cache_dir)
-
-    def __del__(self):
-        # when data loader goes are GC-ed, remove the cache dir. This is needed to not
-        # waste disk space in case that multiple data loaders are used, eg. running
-        # evaluations on multiple datasets during training.
-        if comm.get_local_rank() == 0:
-            DiskCachedDatasetFromList._clean_up_cache_dir(self._cache.cache_dir)
-            DiskCachedDatasetFromList.get_all_cache_dirs().remove(self._cache.cache_dir)
-
-
 class _FakeListObj(list):
     def __init__(self, size):
         self.size = size
@@ -635,6 +426,8 @@ def enable_disk_cached_dataset(cfg):
         return
 
     def _patched_dataset_from_list(lst, **kwargs):
+        from d2go.data.disk_cache import DiskCachedDatasetFromList
+
         logger.info("Patch DatasetFromList with DiskCachedDatasetFromList")
         return DiskCachedDatasetFromList(lst)
diff --git a/setup.py b/setup.py
index 215b2736..329a9060 100644
--- a/setup.py
+++ b/setup.py
@@ -30,6 +30,7 @@
         "pytorch-lightning @ git+https://github.com/PyTorchLightning/pytorch-lightning.git@86b177ebe",
         "opencv-python",
         "parameterized",
+        "diskcache",  # TODO: move to mobile_cv
     ]
diff --git a/tests/data/test_data_loader.py b/tests/data/test_data_loader.py
index ef50697a..b6609b17 100644
--- a/tests/data/test_data_loader.py
+++ b/tests/data/test_data_loader.py
@@ -7,7 +7,8 @@
 import unittest
 
 import torch
-from d2go.data.utils import DiskCachedDatasetFromList, enable_disk_cached_dataset
+from d2go.data.disk_cache import DiskCachedDatasetFromList
+from d2go.data.utils import enable_disk_cached_dataset
 from d2go.runner import create_runner
 from d2go.utils.testing.data_loader_helper import (
     create_fake_detection_data_loader,
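
Editor's note (not part of the patch): a minimal usage sketch of the relocated
class, assuming a single-process run with mobile_cv's LocalCache and the new
"diskcache" dependency installed; the fake items below are made up for
illustration:

    from d2go.data.disk_cache import DiskCachedDatasetFromList

    # any picklable items work; detectron2-style dataset dicts are the typical payload
    fake_lst = [{"image_id": i, "file_name": "img_{}.jpg".format(i)} for i in range(1000)]
    ds = DiskCachedDatasetFromList(fake_lst, strategy="batched_static")
    assert len(ds) == 1000
    print(ds.get_cache_dir())  # e.g. /tmp/DatasetFromList_cache/<8 hex chars>
    item = ds[42]  # items round-trip through pickle and the on-disk cache
    assert item["image_id"] == 42

The cache dir is removed when the dataset is GC-ed, and any leftover dirs are
swept at interpreter exit via the atexit-registered _clean_up_all above.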