Dask distributed async executor #279

Merged · 1 commit · Jul 28, 2023

2 changes: 1 addition & 1 deletion .coveragerc
@@ -5,7 +5,7 @@ omit =
cubed/extensions/*
cubed/runtime/executors/beam.py
cubed/runtime/executors/coiled.py
-cubed/runtime/executors/dask.py
+cubed/runtime/executors/dask*.py
cubed/runtime/executors/lithops.py
cubed/runtime/executors/modal*.py
cubed/vendor/*
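
(Since the omit patterns are globs, `dask*.py` covers both the existing `dask.py` executor and the new `dask_distributed_async.py` module.)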
154 changes: 154 additions & 0 deletions cubed/runtime/executors/dask_distributed_async.py
@@ -0,0 +1,154 @@
import asyncio
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)

from aiostream import stream
from aiostream.core import Stream
from dask.distributed import Client
from networkx import MultiDiGraph

from cubed.core.array import Callback, Spec
from cubed.core.plan import visit_node_generations, visit_nodes
from cubed.primitive.types import CubedPipeline
from cubed.runtime.executors.asyncio import async_map_unordered
from cubed.runtime.types import DagExecutor
from cubed.runtime.utils import execution_stats, gensym, handle_callbacks


# note: the argument can't be named simply `func` here, as it would clash with the `func` parameter of `dask.distributed.Client.map`
@execution_stats
def run_func(input, pipeline_func=None, config=None, name=None):
result = pipeline_func(input, config=config)
return result


async def map_unordered(
client: Client,
map_function: Callable[..., Any],
map_iterdata: Iterable[Union[List[Any], Tuple[Any, ...], Dict[str, Any]]],
retries: int = 2,
use_backups: bool = False,
return_stats: bool = False,
name: Optional[str] = None,
**kwargs,
) -> AsyncIterator[Any]:
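    """Apply a function to each input in an iterable, yielding results as
    they complete (not necessarily in input order).

    If use_backups is True, backup tasks are submitted for stragglers, and
    the first of the original and backup tasks to finish supplies the result.
    """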
def create_futures_func(input, **kwargs):
input = list(input) # dask expects a sequence (it calls `len` on it)
key = name or gensym("map")
key = key.replace("-", "_") # otherwise array number is not shown on dashboard
return [
(i, asyncio.ensure_future(f))
for i, f in zip(
input,
client.map(map_function, input, key=key, retries=retries, **kwargs),
)
]

def create_backup_futures_func(input, **kwargs):
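        # unlike the primary futures above, backups are submitted without retries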
input = list(input) # dask expects a sequence (it calls `len` on it)
key = name or gensym("backup")
key = key.replace("-", "_") # otherwise array number is not shown on dashboard
return [
(i, asyncio.ensure_future(f))
for i, f in zip(input, client.map(map_function, input, key=key, **kwargs))
]

async for result in async_map_unordered(
create_futures_func,
map_iterdata,
use_backups=use_backups,
create_backup_futures_func=create_backup_futures_func,
return_stats=return_stats,
name=name,
**kwargs,
):
yield result


def pipeline_to_stream(
client: Client, name: str, pipeline: CubedPipeline, **kwargs
) -> Stream:
return stream.iterate(
map_unordered(
client,
run_func,
pipeline.mappable,
return_stats=True,
name=name,
pipeline_func=pipeline.function,
config=pipeline.config,
**kwargs,
)
)


async def async_execute_dag(
dag: MultiDiGraph,
callbacks: Optional[Sequence[Callback]] = None,
array_names: Optional[Sequence[str]] = None,
resume: Optional[bool] = None,
spec: Optional[Spec] = None,
compute_arrays_in_parallel: Optional[bool] = None,
compute_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
compute_kwargs = compute_kwargs or {}
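    # compute_kwargs are forwarded to dask.distributed.Client, e.g. an
    # existing scheduler address, or LocalCluster options such as n_workers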
async with Client(asynchronous=True, **compute_kwargs) as client:
if not compute_arrays_in_parallel:
# run one pipeline at a time
for name, node in visit_nodes(dag, resume=resume):
st = pipeline_to_stream(client, name, node["pipeline"], **kwargs)
async with st.stream() as streamer:
async for _, stats in streamer:
handle_callbacks(callbacks, stats)
else:
for gen in visit_node_generations(dag, resume=resume):
# run pipelines in the same topological generation in parallel by merging their streams
streams = [
pipeline_to_stream(client, name, node["pipeline"], **kwargs)
for name, node in gen
]
merged_stream = stream.merge(*streams)
async with merged_stream.stream() as streamer:
async for _, stats in streamer:
handle_callbacks(callbacks, stats)


class AsyncDaskDistributedExecutor(DagExecutor):
"""An execution engine that uses Dask Distributed's async API."""

def __init__(self, **kwargs):
self.kwargs = kwargs

def execute_dag(
self,
dag: MultiDiGraph,
callbacks: Optional[Sequence[Callback]] = None,
array_names: Optional[Sequence[str]] = None,
resume: Optional[bool] = None,
spec: Optional[Spec] = None,
compute_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
merged_kwargs = {**self.kwargs, **kwargs}
asyncio.run(
async_execute_dag(
dag,
callbacks=callbacks,
array_names=array_names,
resume=resume,
spec=spec,
compute_kwargs=compute_kwargs,
**merged_kwargs,
)
)
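
For context, a minimal sketch of how the new executor might be used (illustrative only: the Spec values, array shape, and chunk sizes below are assumptions, not taken from this PR):

import cubed
import cubed.array_api as xp
from cubed.runtime.executors.dask_distributed_async import AsyncDaskDistributedExecutor

# illustrative spec; the work_dir and allowed_mem values are assumptions
spec = cubed.Spec(work_dir="tmp", allowed_mem="1GB")
a = xp.ones((1000, 1000), chunks=(100, 100), spec=spec)
b = xp.negative(a)

# with no compute_kwargs, Client(asynchronous=True) starts a local cluster
b.compute(executor=AsyncDaskDistributedExecutor())
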
100 changes: 100 additions & 0 deletions cubed/tests/runtime/test_dask_distributed_async.py
@@ -0,0 +1,100 @@
import asyncio
from functools import partial

import pytest

from cubed.tests.runtime.utils import check_invocation_counts, deterministic_failure

pytest.importorskip("dask.distributed")

from dask.distributed import Client

from cubed.runtime.executors.dask_distributed_async import map_unordered


async def run_test(function, input, retries, use_backups=False):
outputs = set()
async with Client(asynchronous=True) as client:
async for output in map_unordered(
client,
function,
input,
retries=retries,
use_backups=use_backups,
):
outputs.add(output)
return outputs


# fmt: off
@pytest.mark.parametrize(
"timing_map, n_tasks, retries",
[
# no failures
({}, 3, 2),
# first invocation fails
({0: [-1], 1: [-1], 2: [-1]}, 3, 2),
# first two invocations fail
({0: [-1, -1], 1: [-1, -1], 2: [-1, -1]}, 3, 2),
# first input sleeps once (not tested since timeout is not supported)
# ({0: [20]}, 3, 2),
],
)
# fmt: on
def test_success(tmp_path, timing_map, n_tasks, retries):
outputs = asyncio.run(
run_test(
function=partial(deterministic_failure, tmp_path, timing_map),
input=range(n_tasks),
retries=retries,
)
)

assert outputs == set(range(n_tasks))

check_invocation_counts(tmp_path, timing_map, n_tasks, retries)


# fmt: off
@pytest.mark.parametrize(
"timing_map, n_tasks, retries",
[
# too many failures
({0: [-1], 1: [-1], 2: [-1, -1, -1]}, 3, 2),
],
)
# fmt: on
def test_failure(tmp_path, timing_map, n_tasks, retries):
with pytest.raises(RuntimeError):
asyncio.run(
run_test(
function=partial(deterministic_failure, tmp_path, timing_map),
input=range(n_tasks),
retries=retries,
)
)

check_invocation_counts(tmp_path, timing_map, n_tasks, retries)


# fmt: off
@pytest.mark.parametrize(
"timing_map, n_tasks, retries",
[
({0: [60]}, 10, 2),
],
)
# fmt: on
def test_stragglers(tmp_path, timing_map, n_tasks, retries):
outputs = asyncio.run(
run_test(
function=partial(deterministic_failure, tmp_path, timing_map),
input=range(n_tasks),
retries=retries,
use_backups=True,
)
)

assert outputs == set(range(n_tasks))

check_invocation_counts(tmp_path, timing_map, n_tasks, retries)
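
To make the timing_map cases above easier to read, here is a self-contained sketch of the convention the harness appears to follow (the real deterministic_failure in cubed.tests.runtime.utils persists invocation counts under tmp_path so they survive across worker processes; this sketch uses an in-memory counter instead):

import time
from collections import defaultdict

_counts = defaultdict(int)  # in-memory stand-in for the on-disk invocation counts

def sketch_deterministic_failure(timing_map, i):
    """Negative entries raise, positive entries sleep for that many seconds;
    invocations beyond the end of an input's list succeed immediately."""
    invocation = _counts[i]
    _counts[i] += 1
    behaviour = timing_map.get(i, [])
    if invocation < len(behaviour):
        t = behaviour[invocation]
        if t < 0:
            raise RuntimeError(f"deliberate failure for input {i}")
        time.sleep(t)
    return i
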
11 changes: 11 additions & 0 deletions cubed/tests/utils.py
@@ -40,6 +40,17 @@
except ImportError:
pass

try:
from cubed.runtime.executors.dask_distributed_async import (
AsyncDaskDistributedExecutor,
)

ALL_EXECUTORS.append(AsyncDaskDistributedExecutor())

MAIN_EXECUTORS.append(AsyncDaskDistributedExecutor())
except ImportError:
pass

try:
from cubed.runtime.executors.lithops import LithopsDagExecutor

8 changes: 8 additions & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ diagnostics = [
]
beam = ["apache-beam", "gcsfs"]
dask = ["dask"]
dask-distributed = ["distributed"]
lithops = ["lithops[aws] >= 2.7.0"]
modal = [
"cubed[diagnostics]",
@@ -70,6 +71,13 @@ test-dask = [
"pytest-cov",
"pytest-mock",
]
test-dask-distributed = [
"cubed[dask-distributed,diagnostics]",
"dill",
"pytest",
"pytest-cov",
"pytest-mock",
]
test-modal = [
"cubed[modal]",
"dill",
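
With these extras in place, the new tests can be installed and run locally with something like `pip install -e '.[test-dask-distributed]'` followed by `pytest cubed/tests/runtime/test_dask_distributed_async.py` (a typical invocation, not a command taken from this diff).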