From b8858259b716472d439ce4712c94ecc989293f9d Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 26 Jul 2023 12:15:53 +0100 Subject: [PATCH] Add `async_map_unordered` function, shared by Modal and Python async executors Add support for batching inputs to avoid overwhelming the backend service --- cubed/runtime/executors/asyncio.py | 102 ++++++++++++++++++++ cubed/runtime/executors/modal_async.py | 118 ++++++++--------------- cubed/runtime/executors/python_async.py | 85 ++++++++-------- cubed/runtime/utils.py | 11 +++ cubed/tests/runtime/test_modal_async.py | 37 ++++++- cubed/tests/runtime/test_python_async.py | 54 ++++++++++- 6 files changed, 277 insertions(+), 130 deletions(-) create mode 100644 cubed/runtime/executors/asyncio.py diff --git a/cubed/runtime/executors/asyncio.py b/cubed/runtime/executors/asyncio.py new file mode 100644 index 000000000..0ab209008 --- /dev/null +++ b/cubed/runtime/executors/asyncio.py @@ -0,0 +1,102 @@ +import asyncio +import copy +import time +from asyncio import Future +from typing import Any, AsyncIterator, Callable, Dict, Iterable, List, Optional, Tuple + +from cubed.runtime.backup import should_launch_backup +from cubed.runtime.utils import batched + + +async def async_map_unordered( + create_futures_func: Callable[..., List[Tuple[Any, Future]]], + input: Iterable[Any], + use_backups: bool = False, + create_backup_futures_func: Optional[ + Callable[..., List[Tuple[Any, Future]]] + ] = None, + batch_size: Optional[int] = None, + return_stats: bool = False, + name: Optional[str] = None, + **kwargs, +) -> AsyncIterator[Any]: + """ + Asynchronous parallel map over an iterable input, with support for backups and batching. + """ + + if create_backup_futures_func is None: + create_backup_futures_func = create_futures_func + + if batch_size is None: + inputs = input + else: + input_batches = batched(input, batch_size) + inputs = next(input_batches) + + task_create_tstamp = time.time() + tasks = {task: i for i, task in create_futures_func(inputs, **kwargs)} + pending = set(tasks.keys()) + t = time.monotonic() + start_times = {f: t for f in pending} + end_times = {} + backups: Dict[asyncio.Future, asyncio.Future] = {} + + while pending: + finished, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED, timeout=2 + ) + for task in finished: + # TODO: use exception groups in Python 3.11 to handle case of multiple task exceptions + if task.exception(): + # if the task has a backup that is not done, or is done with no exception, then don't raise this exception + backup = backups.get(task, None) + if backup: + if not backup.done() or not backup.exception(): + continue + raise task.exception() + end_times[task] = time.monotonic() + if return_stats: + result, stats = task.result() + if name is not None: + stats["array_name"] = name + stats["task_create_tstamp"] = task_create_tstamp + yield result, stats + else: + yield task.result() + + # remove any backup task + if use_backups: + backup = backups.get(task, None) + if backup: + if backup in pending: + pending.remove(backup) + del backups[task] + del backups[backup] + backup.cancel() + + if use_backups: + now = time.monotonic() + for task in copy.copy(pending): + if task not in backups and should_launch_backup( + task, now, start_times, end_times + ): + # launch backup task + print("Launching backup task") + i = tasks[task] + i, new_task = create_backup_futures_func([i], **kwargs)[0] + tasks[new_task] = i + start_times[new_task] = time.monotonic() + pending.add(new_task) + backups[task] = new_task + backups[new_task] = task + + if batch_size is not None and len(pending) < batch_size: + inputs = next(input_batches, None) # type: ignore + if inputs is not None: + new_tasks = { + task: i for i, task in create_futures_func(inputs, **kwargs) + } + tasks.update(new_tasks) + pending.update(new_tasks.keys()) + t = time.monotonic() + start_times = {f: t for f in new_tasks.keys()} diff --git a/cubed/runtime/executors/modal_async.py b/cubed/runtime/executors/modal_async.py index 275a02c5e..e02c43c84 100644 --- a/cubed/runtime/executors/modal_async.py +++ b/cubed/runtime/executors/modal_async.py @@ -1,9 +1,7 @@ import asyncio -import copy import time from asyncio.exceptions import TimeoutError -from functools import partial -from typing import Any, AsyncIterator, Dict, Iterable, Optional, Sequence +from typing import Any, AsyncIterator, Iterable, Optional, Sequence from aiostream import stream from modal.exception import ConnectionError @@ -13,7 +11,7 @@ from cubed.core.array import Callback, Spec from cubed.core.plan import visit_node_generations, visit_nodes -from cubed.runtime.backup import should_launch_backup +from cubed.runtime.executors.asyncio import async_map_unordered from cubed.runtime.executors.modal import ( Container, check_runtime_memory, @@ -30,6 +28,7 @@ async def map_unordered( input: Iterable[Any], use_backups: bool = False, backup_function: Optional[Function] = None, + batch_size: Optional[int] = None, return_stats: bool = False, name: Optional[str] = None, **kwargs, @@ -45,9 +44,9 @@ async def map_unordered( :return: Function values (and optionally stats) as they are completed, not necessarily in the input order. """ - task_create_tstamp = time.time() - if not use_backups: + if not use_backups and batch_size is None: + task_create_tstamp = time.time() async for result in app_function.map(input, order_outputs=False, kwargs=kwargs): if return_stats: result, stats = result @@ -59,86 +58,45 @@ async def map_unordered( yield result return + def create_futures_func(input, **kwargs): + return [ + (i, asyncio.ensure_future(app_function.call.aio(i, **kwargs))) + for i in input + ] + backup_function = backup_function or app_function - tasks = { - asyncio.ensure_future(app_function.call.aio(i, **kwargs)): i for i in input - } - pending = set(tasks.keys()) - t = time.monotonic() - start_times = {f: t for f in pending} - end_times = {} - backups: Dict[asyncio.Future, asyncio.Future] = {} - - while pending: - finished, pending = await asyncio.wait( - pending, return_when=asyncio.FIRST_COMPLETED, timeout=2 - ) - for task in finished: - # TODO: use exception groups in Python 3.11 to handle case of multiple task exceptions - if task.exception(): - # if the task has a backup that is not done, or is done with no exception, then don't raise this exception - backup = backups.get(task, None) - if backup: - if not backup.done() or not backup.exception(): - continue - raise task.exception() - end_times[task] = time.monotonic() - if return_stats: - result, stats = task.result() - if name is not None: - stats["array_name"] = name - stats["task_create_tstamp"] = task_create_tstamp - yield result, stats - else: - yield task.result() - - # remove any backup task - if use_backups: - backup = backups.get(task, None) - if backup: - if backup in pending: - pending.remove(backup) - del backups[task] - del backups[backup] - backup.cancel() - - if use_backups: - now = time.monotonic() - for task in copy.copy(pending): - if task not in backups and should_launch_backup( - task, now, start_times, end_times - ): - # launch backup task - print("Launching backup task") - i = tasks[task] - new_task = asyncio.ensure_future( - backup_function.call.aio(i, **kwargs) - ) - tasks[new_task] = i - start_times[new_task] = time.monotonic() - pending.add(new_task) - backups[task] = new_task - backups[new_task] = task + def create_backup_futures_func(input, **kwargs): + return [ + (i, asyncio.ensure_future(backup_function.call.aio(i, **kwargs))) + for i in input + ] + + async for result in async_map_unordered( + create_futures_func, + input, + use_backups=use_backups, + create_backup_futures_func=create_backup_futures_func, + batch_size=batch_size, + return_stats=return_stats, + name=name, + **kwargs, + ): + yield result def pipeline_to_stream(app_function, name, pipeline, **kwargs): - it = stream.iterate( - [ - partial( - map_unordered, - app_function, - pipeline.mappable, - return_stats=True, - name=name, - func=pipeline.function, - config=pipeline.config, - **kwargs, - ) - ] + return stream.iterate( + map_unordered( + app_function, + pipeline.mappable, + return_stats=True, + name=name, + func=pipeline.function, + config=pipeline.config, + **kwargs, + ) ) - # concat stages, running only one stage at a time - return stream.concatmap(it, lambda f: f(), task_limit=1) # This just retries the initial connection attempt, not the function calls diff --git a/cubed/runtime/executors/python_async.py b/cubed/runtime/executors/python_async.py index 57fedd76b..588c89a1d 100644 --- a/cubed/runtime/executors/python_async.py +++ b/cubed/runtime/executors/python_async.py @@ -1,5 +1,4 @@ import asyncio -import time from concurrent.futures import Executor, ThreadPoolExecutor from functools import partial from typing import Any, AsyncIterator, Callable, Iterable, Optional, Sequence @@ -12,6 +11,7 @@ from cubed.core.array import Callback, Spec from cubed.core.plan import visit_node_generations, visit_nodes from cubed.primitive.types import CubedPipeline +from cubed.runtime.executors.asyncio import async_map_unordered from cubed.runtime.types import DagExecutor from cubed.runtime.utils import execution_stats, handle_callbacks @@ -29,65 +29,55 @@ async def map_unordered( input: Iterable[Any], retries: int = 2, use_backups: bool = False, + batch_size: Optional[int] = None, return_stats: bool = False, name: Optional[str] = None, **kwargs, ) -> AsyncIterator[Any]: - if name is not None: - print(f"{name}: running map_unordered") if retries == 0: retrying_function = function else: retryer = Retrying(reraise=True, stop=stop_after_attempt(retries + 1)) retrying_function = partial(retryer, function) - task_create_tstamp = time.time() - tasks = { - asyncio.wrap_future( - concurrent_executor.submit(retrying_function, i, **kwargs) - ): i - for i in input - } - pending = set(tasks.keys()) - - while pending: - finished, pending = await asyncio.wait( - pending, return_when=asyncio.FIRST_COMPLETED, timeout=2 - ) - for task in finished: - # TODO: use exception groups in Python 3.11 to handle case of multiple task exceptions - if task.exception(): - raise task.exception() - if return_stats: - result, stats = task.result() - if name is not None: - stats["array_name"] = name - stats["task_create_tstamp"] = task_create_tstamp - yield result, stats - else: - yield task.result() + def create_futures_func(input, **kwargs): + return [ + ( + i, + asyncio.wrap_future( + concurrent_executor.submit(retrying_function, i, **kwargs) + ), + ) + for i in input + ] + + async for result in async_map_unordered( + create_futures_func, + input, + use_backups=use_backups, + batch_size=batch_size, + return_stats=return_stats, + name=name, + **kwargs, + ): + yield result def pipeline_to_stream( concurrent_executor: Executor, name: str, pipeline: CubedPipeline, **kwargs ) -> Stream: - it = stream.iterate( - [ - partial( - map_unordered, - concurrent_executor, - run_func, - pipeline.mappable, - return_stats=True, - name=name, - func=pipeline.function, - config=pipeline.config, - **kwargs, - ) - ] + return stream.iterate( + map_unordered( + concurrent_executor, + run_func, + pipeline.mappable, + return_stats=True, + name=name, + func=pipeline.function, + config=pipeline.config, + **kwargs, + ) ) - # concat stages, running only one stage at a time - return stream.concatmap(it, lambda f: f(), task_limit=1) async def async_execute_dag( @@ -99,7 +89,8 @@ async def async_execute_dag( compute_arrays_in_parallel: Optional[bool] = None, **kwargs, ) -> None: - with ThreadPoolExecutor() as concurrent_executor: + concurrent_executor = ThreadPoolExecutor() + try: if not compute_arrays_in_parallel: # run one pipeline at a time for name, node in visit_nodes(dag, resume=resume): @@ -123,6 +114,10 @@ async def async_execute_dag( async for _, stats in streamer: handle_callbacks(callbacks, stats) + finally: + # don't wait for any cancelled tasks + concurrent_executor.shutdown(wait=False) + class AsyncPythonDagExecutor(DagExecutor): """An execution engine that uses Python asyncio.""" diff --git a/cubed/runtime/utils.py b/cubed/runtime/utils.py index a8d31856f..8ffa32d64 100644 --- a/cubed/runtime/utils.py +++ b/cubed/runtime/utils.py @@ -1,5 +1,6 @@ import time from functools import partial +from itertools import islice from cubed.utils import peak_measured_mem @@ -52,3 +53,13 @@ def handle_callbacks(callbacks, stats): else: event = TaskEndEvent(**stats) [callback.on_task_end(event) for callback in callbacks] + + +# this will be in Python 3.12 https://docs.python.org/3.12/library/itertools.html#itertools.batched +def batched(iterable, n): + # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + raise ValueError("n must be at least one") + it = iter(iterable) + while batch := tuple(islice(it, n)): + yield batch diff --git a/cubed/tests/runtime/test_modal_async.py b/cubed/tests/runtime/test_modal_async.py index 503612711..46a45cdb5 100644 --- a/cubed/tests/runtime/test_modal_async.py +++ b/cubed/tests/runtime/test_modal_async.py @@ -1,3 +1,5 @@ +import itertools + import pytest modal = pytest.importorskip("modal") @@ -36,6 +38,11 @@ def deterministic_failure_modal(i, path=None, timing_map=None): return deterministic_failure(path, timing_map, i) +@stub.function(image=image, secret=modal.Secret.from_name("my-aws-secret"), timeout=10) +def deterministic_failure_modal_no_retries(i, path=None, timing_map=None): + return deterministic_failure(path, timing_map, i) + + @stub.function( image=image, secret=modal.Secret.from_name("my-aws-secret"), retries=2, timeout=300 ) @@ -43,11 +50,15 @@ def deterministic_failure_modal_long_timeout(i, path=None, timing_map=None): return deterministic_failure(path, timing_map, i) -async def run_test(app_function, input, use_backups=False, **kwargs): +async def run_test(app_function, input, use_backups=False, batch_size=None, **kwargs): outputs = set() async with stub.run(): async for output in map_unordered( - app_function, input, use_backups=use_backups, **kwargs + app_function, + input, + use_backups=use_backups, + batch_size=batch_size, + **kwargs, ): outputs.add(output) return outputs @@ -173,3 +184,25 @@ def test_stragglers(timing_map, n_tasks, retries, expected_invocation_counts_ove finally: fs = fsspec.open(tmp_path).fs fs.rm(tmp_path, recursive=True) + + +@pytest.mark.cloud +def test_batch(tmp_path): + # input is unbounded, so if entire input were consumed and not read + # in batches then it would never return, since it would never + # run the first (failing) input + try: + with pytest.raises(RuntimeError): + asyncio.run( + run_test( + app_function=deterministic_failure_modal_no_retries, + input=itertools.count(), + path=tmp_path, + timing_map={0: [-1]}, + batch_size=10, + ) + ) + + finally: + fs = fsspec.open(tmp_path).fs + fs.rm(tmp_path, recursive=True) diff --git a/cubed/tests/runtime/test_python_async.py b/cubed/tests/runtime/test_python_async.py index 979f86b10..3e27d7974 100644 --- a/cubed/tests/runtime/test_python_async.py +++ b/cubed/tests/runtime/test_python_async.py @@ -1,4 +1,5 @@ import asyncio +import itertools from concurrent.futures import ThreadPoolExecutor from functools import partial @@ -8,13 +9,21 @@ from cubed.tests.runtime.utils import check_invocation_counts, deterministic_failure -async def run_test(function, input, retries=2): +async def run_test(function, input, retries=2, use_backups=False, batch_size=None): outputs = set() - with ThreadPoolExecutor() as concurrent_executor: + concurrent_executor = ThreadPoolExecutor() + try: async for output in map_unordered( - concurrent_executor, function, input, retries=retries + concurrent_executor, + function, + input, + retries=retries, + use_backups=use_backups, + batch_size=batch_size, ): outputs.add(output) + finally: + concurrent_executor.shutdown(wait=False) return outputs @@ -64,3 +73,42 @@ def test_failure(tmp_path, timing_map, n_tasks, retries): ) check_invocation_counts(tmp_path, timing_map, n_tasks, retries) + + +# fmt: off +@pytest.mark.parametrize( + "timing_map, n_tasks, retries", + [ + ({0: [60]}, 10, 2), + ], +) +# fmt: on +@pytest.mark.skip(reason="This passes, but Python will not exit until the slow task is done.") +def test_stragglers(tmp_path, timing_map, n_tasks, retries): + outputs = asyncio.run( + run_test( + function=partial(deterministic_failure, tmp_path, timing_map), + input=range(n_tasks), + retries=retries, + use_backups=True, + ) + ) + + assert outputs == set(range(n_tasks)) + + check_invocation_counts(tmp_path, timing_map, n_tasks, retries) + + +def test_batch(tmp_path): + # input is unbounded, so if entire input were consumed and not read + # in batches then it would never return, since it would never + # run the first (failing) input + with pytest.raises(RuntimeError): + asyncio.run( + run_test( + function=partial(deterministic_failure, tmp_path, {0: [-1]}), + input=itertools.count(), + retries=0, + batch_size=10, + ) + )