Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ability to redefine cache timeout #571

Closed
14 changes: 13 additions & 1 deletion test/test_local_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from _utils._common_utils_for_test import create_temp_dir, create_temp_files, get_name, reset_after_n_next_calls

from torch.utils.data import DataLoader

from torchdata.dataloader2.adapter import CacheTimeout
from torchdata.datapipes.iter import (
Bz2FileLoader,
CSVDictParser,
Expand Down Expand Up @@ -629,7 +631,7 @@ def _write_text_files(self):
def _slow_fn(tmpdirname, x):
with open(os.path.join(tmpdirname, str(os.getpid())), "w") as pid_fh:
pid_fh.write("anything")
time.sleep(2)
time.sleep(10)
return (x, "str")

def test_disk_cache_locks(self):
Expand All @@ -650,6 +652,16 @@ def test_disk_cache_locks(self):
self.assertEqual(2, len(all_files))
self.assertEqual("str", result[0][1])

# cleanup cached files
for f in os.listdir(tmpdirname):
os.remove(os.path.join(tmpdirname, f))

dp = StreamReader(dp)
dp = CacheTimeout(2)(dp) # Calling adapter manually to work with classic DataLoader
dl = DataLoader(dp, num_workers=10, multiprocessing_context="spawn", batch_size=1, collate_fn=_unbatch)
with self.assertRaisesRegex(Exception, "OnDiskCache Exception"):
result = list(dl)

# TODO(120): this test currently only covers reading from local
# filesystem. It needs to be modified once test data can be stored on
# gdrive/s3/onedrive
Expand Down
30 changes: 30 additions & 0 deletions torchdata/dataloader2/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import torch

from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.iter.util.cacheholder import _WaitPendingCacheItemIterDataPipe

__all__ = [
"Adapter",
"CacheTimeout",
"Shuffle",
]

Expand Down Expand Up @@ -45,3 +47,31 @@ def __init__(self, enable=True):

def __call__(self, datapipe: IterDataPipe) -> IterDataPipe:
return torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=self.enable)


class CacheTimeout(Adapter):
r"""
CacheTimeout DataPipes adapter allows control over timeouts of all existing EndOnDiskCacheHolder (``end_caching``)
DataPipes in the graph. Usefull when cached pipeline takes to long to execute (ex. slow file downloading).

Args:
timeout: int - amount of seconds parallel processes will wait for cached files to appear.

Example:
>>> dl = DataLoader2(dp, [CacheTimeout(600)])
"""

def __init__(self, timeout=None):
if timeout is None:
raise ValueError("timeout should be integer")
self.timeout = timeout

def __call__(self, datapipe: IterDataPipe) -> IterDataPipe:
graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
all_pipes = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
cache_locks = {pipe for pipe in all_pipes if isinstance(pipe, _WaitPendingCacheItemIterDataPipe)}

for cache_lock in cache_locks:
cache_lock.set_timeout(self.timeout)

return datapipe
34 changes: 21 additions & 13 deletions torchdata/datapipes/iter/util/cacheholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import hashlib
import inspect
import os.path
Expand Down Expand Up @@ -292,17 +291,26 @@ def _is_promise_pending(promise_filename):
return os.path.exists(promise_filename)


def _wait_promise_fn(timeout, filename):
promise_filename = _find_promise_file(filename)
start = time.time()
while _is_promise_pending(promise_filename):
time.sleep(0.01)
if time.time() - start > timeout:
raise Exception(
f"OnDiskCache Exception: {filename} expected to be written by different process, "
+ f"but file is not ready in {timeout} seconds."
)
return filename
class _WaitPendingCacheItemIterDataPipe(IterDataPipe):
def __init__(self, source_datapipe, timeout=300):
self.source_datapipe = source_datapipe
self.timeout = timeout

def set_timeout(self, timeout):
self.timeout = timeout

def __iter__(self):
for filename in self.source_datapipe:
promise_filename = _find_promise_file(filename)
start = time.time()
while _is_promise_pending(promise_filename):
time.sleep(0.01)
if time.time() - start > self.timeout:
raise Exception(
f"OnDiskCache Exception: {filename} expected to be written by different process, "
+ f"but file is not ready in {self.timeout} seconds."
)
yield filename


class _FulfilledPromisesIterDataPipe(IterDataPipe):
Expand Down Expand Up @@ -398,7 +406,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals

_filepath_fn, _hash_dict, _hash_type, _ = OnDiskCacheHolderIterDataPipe._temp_dict[cache_holder]
cached_dp = cache_holder._end_caching()
cached_dp = cached_dp.map(functools.partial(_wait_promise_fn, timeout))
cached_dp = _WaitPendingCacheItemIterDataPipe(cached_dp, timeout=timeout)
cached_dp = FileLister(cached_dp, recursive=True)

if same_filepath_fn:
Expand Down