Add async_map_unordered function, shared by Modal and Python async executors

Add support for batching inputs to avoid overwhelming the backend service
tomwhite committed Jul 27, 2023
1 parent 9e60420 commit b885825
Showing 6 changed files with 277 additions and 130 deletions.
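Before the file-by-file diff, here is a minimal sketch of calling the new `async_map_unordered` entry point directly with a batch size. Only the `create_futures_func` contract (a list of `(input, future)` pairs) and the `batch_size` argument come from the code in this commit; the toy `call_backend` coroutine and the surrounding `main` are hypothetical.

```python
import asyncio

from cubed.runtime.executors.asyncio import async_map_unordered


async def call_backend(x):
    # Hypothetical stand-in for a real backend call (e.g. a Modal function).
    await asyncio.sleep(0.01)
    return x * 2


def create_futures_func(inputs, **kwargs):
    # Return (input, future) pairs, as async_map_unordered expects.
    return [(i, asyncio.ensure_future(call_backend(i))) for i in inputs]


async def main():
    # Submit inputs in batches of 4 rather than all at once, so the backend
    # service is not overwhelmed; results arrive as they complete, unordered.
    async for result in async_map_unordered(
        create_futures_func,
        range(10),
        batch_size=4,
    ):
        print(result)


asyncio.run(main())
```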
102 changes: 102 additions & 0 deletions cubed/runtime/executors/asyncio.py
@@ -0,0 +1,102 @@
import asyncio
import copy
import time
from asyncio import Future
from typing import Any, AsyncIterator, Callable, Dict, Iterable, List, Optional, Tuple

from cubed.runtime.backup import should_launch_backup
from cubed.runtime.utils import batched


async def async_map_unordered(
create_futures_func: Callable[..., List[Tuple[Any, Future]]],
input: Iterable[Any],
use_backups: bool = False,
create_backup_futures_func: Optional[
Callable[..., List[Tuple[Any, Future]]]
] = None,
batch_size: Optional[int] = None,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
) -> AsyncIterator[Any]:
"""
Asynchronous parallel map over an iterable input, with support for backups and batching.
"""

if create_backup_futures_func is None:
create_backup_futures_func = create_futures_func

if batch_size is None:
inputs = input
else:
input_batches = batched(input, batch_size)
inputs = next(input_batches)

task_create_tstamp = time.time()
tasks = {task: i for i, task in create_futures_func(inputs, **kwargs)}
pending = set(tasks.keys())
t = time.monotonic()
start_times = {f: t for f in pending}
end_times = {}
backups: Dict[asyncio.Future, asyncio.Future] = {}

while pending:
finished, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED, timeout=2
)
for task in finished:
# TODO: use exception groups in Python 3.11 to handle case of multiple task exceptions
if task.exception():
# if the task has a backup that is not done, or is done with no exception, then don't raise this exception
backup = backups.get(task, None)
if backup:
if not backup.done() or not backup.exception():
continue
raise task.exception()
end_times[task] = time.monotonic()
if return_stats:
result, stats = task.result()
if name is not None:
stats["array_name"] = name
stats["task_create_tstamp"] = task_create_tstamp
yield result, stats
else:
yield task.result()

# remove any backup task
if use_backups:
backup = backups.get(task, None)
if backup:
if backup in pending:
pending.remove(backup)
del backups[task]
del backups[backup]
backup.cancel()

if use_backups:
now = time.monotonic()
for task in copy.copy(pending):
if task not in backups and should_launch_backup(
task, now, start_times, end_times
):
# launch backup task
print("Launching backup task")
i = tasks[task]
i, new_task = create_backup_futures_func([i], **kwargs)[0]
tasks[new_task] = i
start_times[new_task] = time.monotonic()
pending.add(new_task)
backups[task] = new_task
backups[new_task] = task

if batch_size is not None and len(pending) < batch_size:
inputs = next(input_batches, None) # type: ignore
if inputs is not None:
new_tasks = {
task: i for i, task in create_futures_func(inputs, **kwargs)
}
tasks.update(new_tasks)
pending.update(new_tasks.keys())
t = time.monotonic()
start_times = {f: t for f in new_tasks.keys()}
118 changes: 38 additions & 80 deletions cubed/runtime/executors/modal_async.py
@@ -1,9 +1,7 @@
import asyncio
import copy
import time
from asyncio.exceptions import TimeoutError
from functools import partial
from typing import Any, AsyncIterator, Dict, Iterable, Optional, Sequence
from typing import Any, AsyncIterator, Iterable, Optional, Sequence

from aiostream import stream
from modal.exception import ConnectionError
@@ -13,7 +11,7 @@

from cubed.core.array import Callback, Spec
from cubed.core.plan import visit_node_generations, visit_nodes
from cubed.runtime.backup import should_launch_backup
from cubed.runtime.executors.asyncio import async_map_unordered
from cubed.runtime.executors.modal import (
Container,
check_runtime_memory,
@@ -30,6 +28,7 @@ async def map_unordered(
input: Iterable[Any],
use_backups: bool = False,
backup_function: Optional[Function] = None,
batch_size: Optional[int] = None,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
@@ -45,9 +44,9 @@ async def map_unordered(
:return: Function values (and optionally stats) as they are completed, not necessarily in the input order.
"""
task_create_tstamp = time.time()

if not use_backups:
if not use_backups and batch_size is None:
task_create_tstamp = time.time()
async for result in app_function.map(input, order_outputs=False, kwargs=kwargs):
if return_stats:
result, stats = result
@@ -59,86 +58,45 @@
yield result
return

def create_futures_func(input, **kwargs):
return [
(i, asyncio.ensure_future(app_function.call.aio(i, **kwargs)))
for i in input
]

backup_function = backup_function or app_function

tasks = {
asyncio.ensure_future(app_function.call.aio(i, **kwargs)): i for i in input
}
pending = set(tasks.keys())
t = time.monotonic()
start_times = {f: t for f in pending}
end_times = {}
backups: Dict[asyncio.Future, asyncio.Future] = {}

while pending:
finished, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED, timeout=2
)
for task in finished:
# TODO: use exception groups in Python 3.11 to handle case of multiple task exceptions
if task.exception():
# if the task has a backup that is not done, or is done with no exception, then don't raise this exception
backup = backups.get(task, None)
if backup:
if not backup.done() or not backup.exception():
continue
raise task.exception()
end_times[task] = time.monotonic()
if return_stats:
result, stats = task.result()
if name is not None:
stats["array_name"] = name
stats["task_create_tstamp"] = task_create_tstamp
yield result, stats
else:
yield task.result()

# remove any backup task
if use_backups:
backup = backups.get(task, None)
if backup:
if backup in pending:
pending.remove(backup)
del backups[task]
del backups[backup]
backup.cancel()

if use_backups:
now = time.monotonic()
for task in copy.copy(pending):
if task not in backups and should_launch_backup(
task, now, start_times, end_times
):
# launch backup task
print("Launching backup task")
i = tasks[task]
new_task = asyncio.ensure_future(
backup_function.call.aio(i, **kwargs)
)
tasks[new_task] = i
start_times[new_task] = time.monotonic()
pending.add(new_task)
backups[task] = new_task
backups[new_task] = task
def create_backup_futures_func(input, **kwargs):
return [
(i, asyncio.ensure_future(backup_function.call.aio(i, **kwargs)))
for i in input
]

async for result in async_map_unordered(
create_futures_func,
input,
use_backups=use_backups,
create_backup_futures_func=create_backup_futures_func,
batch_size=batch_size,
return_stats=return_stats,
name=name,
**kwargs,
):
yield result


def pipeline_to_stream(app_function, name, pipeline, **kwargs):
it = stream.iterate(
[
partial(
map_unordered,
app_function,
pipeline.mappable,
return_stats=True,
name=name,
func=pipeline.function,
config=pipeline.config,
**kwargs,
)
]
return stream.iterate(
map_unordered(
app_function,
pipeline.mappable,
return_stats=True,
name=name,
func=pipeline.function,
config=pipeline.config,
**kwargs,
)
)
# concat stages, running only one stage at a time
return stream.concatmap(it, lambda f: f(), task_limit=1)


# This just retries the initial connection attempt, not the function calls
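With this change, `pipeline_to_stream` returns a plain `stream.iterate` over `map_unordered`, leaving stage sequencing to the caller. A hedged sketch of how such per-pipeline streams are typically consumed, mirroring the `async for _, stats in streamer` pattern visible in the `async_execute_dag` hunk further down; the `run_pipelines_together` name and the merge-based parallel consumption are assumptions, not part of this commit.

```python
from aiostream import stream

from cubed.runtime.utils import handle_callbacks


async def run_pipelines_together(streams, callbacks):
    # Assumed consumption pattern: merge several pipeline streams and react to
    # each (result, stats) pair as tasks complete.
    merged = stream.merge(*streams)
    async with merged.stream() as streamer:
        async for _, stats in streamer:
            handle_callbacks(callbacks, stats)
```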
85 changes: 40 additions & 45 deletions cubed/runtime/executors/python_async.py
@@ -1,5 +1,4 @@
import asyncio
import time
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterable, Optional, Sequence
@@ -12,6 +11,7 @@
from cubed.core.array import Callback, Spec
from cubed.core.plan import visit_node_generations, visit_nodes
from cubed.primitive.types import CubedPipeline
from cubed.runtime.executors.asyncio import async_map_unordered
from cubed.runtime.types import DagExecutor
from cubed.runtime.utils import execution_stats, handle_callbacks

@@ -29,65 +29,55 @@ async def map_unordered(
input: Iterable[Any],
retries: int = 2,
use_backups: bool = False,
batch_size: Optional[int] = None,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
) -> AsyncIterator[Any]:
if name is not None:
print(f"{name}: running map_unordered")
if retries == 0:
retrying_function = function
else:
retryer = Retrying(reraise=True, stop=stop_after_attempt(retries + 1))
retrying_function = partial(retryer, function)

task_create_tstamp = time.time()
tasks = {
asyncio.wrap_future(
concurrent_executor.submit(retrying_function, i, **kwargs)
): i
for i in input
}
pending = set(tasks.keys())

while pending:
finished, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED, timeout=2
)
for task in finished:
# TODO: use exception groups in Python 3.11 to handle case of multiple task exceptions
if task.exception():
raise task.exception()
if return_stats:
result, stats = task.result()
if name is not None:
stats["array_name"] = name
stats["task_create_tstamp"] = task_create_tstamp
yield result, stats
else:
yield task.result()
def create_futures_func(input, **kwargs):
return [
(
i,
asyncio.wrap_future(
concurrent_executor.submit(retrying_function, i, **kwargs)
),
)
for i in input
]

async for result in async_map_unordered(
create_futures_func,
input,
use_backups=use_backups,
batch_size=batch_size,
return_stats=return_stats,
name=name,
**kwargs,
):
yield result


def pipeline_to_stream(
concurrent_executor: Executor, name: str, pipeline: CubedPipeline, **kwargs
) -> Stream:
it = stream.iterate(
[
partial(
map_unordered,
concurrent_executor,
run_func,
pipeline.mappable,
return_stats=True,
name=name,
func=pipeline.function,
config=pipeline.config,
**kwargs,
)
]
return stream.iterate(
map_unordered(
concurrent_executor,
run_func,
pipeline.mappable,
return_stats=True,
name=name,
func=pipeline.function,
config=pipeline.config,
**kwargs,
)
)
# concat stages, running only one stage at a time
return stream.concatmap(it, lambda f: f(), task_limit=1)


async def async_execute_dag(
@@ -99,7 +89,8 @@ async def async_execute_dag(
compute_arrays_in_parallel: Optional[bool] = None,
**kwargs,
) -> None:
with ThreadPoolExecutor() as concurrent_executor:
concurrent_executor = ThreadPoolExecutor()
try:
if not compute_arrays_in_parallel:
# run one pipeline at a time
for name, node in visit_nodes(dag, resume=resume):
@@ -123,6 +114,10 @@
async for _, stats in streamer:
handle_callbacks(callbacks, stats)

finally:
# don't wait for any cancelled tasks
concurrent_executor.shutdown(wait=False)


class AsyncPythonDagExecutor(DagExecutor):
"""An execution engine that uses Python asyncio."""