Commit 7d31b76: spans

crusaderky committed May 30, 2023
1 parent e727d19  commit 7d31b76
Showing 6 changed files with 285 additions and 28 deletions.
3 changes: 2 additions & 1 deletion distributed/diagnostics/tests/test_scheduler_plugin.py
@@ -422,7 +422,7 @@ def update_graph(  # type: ignore
assert not kwargs
assert keys == {"foo"}
assert tasks == ["foo"]
assert annotations == {}
assert annotations == {"span": {"foo": (c.id,)}}
assert len(priority) == 1
assert isinstance(priority["foo"], tuple)
assert dependencies == {"foo": set()}
@@ -467,6 +467,7 @@ def update_graph(  # type: ignore
"layer": {"f2": "explicit"},
"len_key": {"f3": 2},
"priority": {"f2": 13},
"span": {k: (c.id,) for k in tasks},
}
assert len(priority) == len(tasks), priority
assert priority["f2"][0] == -13
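With this commit, every task submitted through a client carries a "span" annotation whose first element is the client ID, which is why the plugin test above now expects ``{"span": {"foo": (c.id,)}}``. A minimal sketch of a custom scheduler plugin that reads this annotation is shown below; the plugin class and the logging are illustrative and not part of this commit.

```python
from __future__ import annotations

import logging

from distributed.diagnostics.plugin import SchedulerPlugin

logger = logging.getLogger(__name__)


class SpanLogger(SchedulerPlugin):
    """Hypothetical plugin: log the span of every task seen by update_graph."""

    def update_graph(self, scheduler, *, annotations=None, **kwargs):
        # After this commit, the resolved annotations contain a "span" entry
        # mapping each key to a tuple such as ("Client-...", "my_workflow").
        for key, span_id in (annotations or {}).get("span", {}).items():
            logger.info("task %s belongs to span %s", key, span_id)
```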
17 changes: 13 additions & 4 deletions distributed/scheduler.py
@@ -98,6 +98,7 @@
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.shuffle import ShuffleSchedulerExtension
from distributed.spans import SpansExtension
from distributed.stealing import WorkStealing
from distributed.utils import (
All,
@@ -169,6 +170,7 @@
"amm": ActiveMemoryManagerExtension,
"memory_sampler": MemorySamplerExtension,
"shuffle": ShuffleSchedulerExtension,
"spans": SpansExtension,
"stealing": WorkStealing,
}

@@ -4362,6 +4364,7 @@ def update_graph(
# required to satisfy the current plugin API. This should be
# reconsidered.
resolved_annotations = self._parse_and_apply_annotations(
client=client,
tasks=new_tasks,
annotations=annotations,
layer_annotations=layer_annotations,
@@ -4480,6 +4483,7 @@ def _generate_taskstates(

def _parse_and_apply_annotations(
self,
client: str,
tasks: Iterable[TaskState],
annotations: dict,
layer_annotations: dict[str, dict],
@@ -4513,14 +4517,19 @@ def _parse_and_apply_annotations(
...
}
"""
resolved_annotations: dict[str, dict[str, Any]] = defaultdict(dict)
resolved_annotations: defaultdict[str, dict[str, Any]] = defaultdict(dict)
for ts in tasks:
key = ts.key
# This could be a typed dict
if not annotations and key not in layer_annotations:
continue
out = annotations.copy()
out.update(layer_annotations.get(key, {}))

spans_ext: SpansExtension | None = self.extensions.get("spans")
if spans_ext:
span_id = out.get("span", ())
assert isinstance(span_id, (list, tuple))
out["span"] = span_id = (client, *span_id)
spans_ext.ensure_span(span_id)

for annot, value in out.items():
# Pop the key since names don't always match attributes
if callable(value):
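The scheduler change above threads the client ID into annotation resolution: for every task, the optional ``span`` tags set by the client are prefixed with the client ID, and the resulting tuple is registered with the ``SpansExtension`` via ``ensure_span``. A standalone sketch of that resolution step, with illustrative names and inputs, looks like this:

```python
from __future__ import annotations


def resolve_span(client: str, annotations: dict) -> tuple[str, ...]:
    """Sketch of the span handling added to _parse_and_apply_annotations."""
    # Tags set by the client through the span() context manager, if any
    tags = annotations.get("span", ())
    assert isinstance(tags, (list, tuple))
    # The client ID always becomes the root of the span
    return (client, *tags)


print(resolve_span("Client-abc123", {"span": ("my_workflow", "phase 1")}))
# ('Client-abc123', 'my_workflow', 'phase 1')
print(resolve_span("Client-abc123", {}))
# ('Client-abc123',)
```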
107 changes: 107 additions & 0 deletions distributed/spans.py
@@ -0,0 +1,107 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

import dask.config

if TYPE_CHECKING:
from distributed import Scheduler


@contextmanager
def span(*tags: str) -> Iterator[None]:
"""Tag group of tasks to be part of a certain group, called a span.
This context manager can be nested, thus creating sub-spans.
Each dask.distributed Client automatically defines a root span, which is its own
random client ID.
Examples
--------
>>> import dask.array as da
>>> import distributed
>>> client = distributed.Client()
>>> with span("my_workflow"):
... with span("phase 1"):
... a = da.random.random(10)
... b = a + 1
... with span("phase 2"):
... c = b * 2
>>> c.compute()
In the above example,
- Tasks of collections a and b will be annotated on the scheduler and workers with
``{'span': ('Client-6e31a38d-fbe3-11ed-83dd-b42e99c1ab7d', 'my_workflow', 'phase 1')}``
- Tasks of collection c (that aren't already part of a or b) will be annotated with
``{'span': ('Client-6e31a38d-fbe3-11ed-83dd-b42e99c1ab7d', 'my_workflow', 'phase 2')}``
The client ID will change randomly every time the client is reinitialized.
You may also set more than one tag at once; e.g.
>>> with span("workflow1", "version1"):
... ...
Note
----
Spans are based on annotations, and just like annotations they can be lost during
optimization. Set config ``optimize.fuse.active: false`` to prevent this issue.
"""
prev_id = dask.config.get("annotations.span", ())
with dask.config.set({"annotations.span": prev_id + tags}):
yield


class Span:
id: tuple[str, ...]
children: set[Span]

__slots__ = tuple(__annotations__)

def __init__(self, span_id: tuple[str, ...]):
self.id = span_id
self.children = set()

def __repr__(self) -> str:
return f"Span{self.id}"


class SpansExtension:
"""Scheduler extension for spans support"""

#: All Span objects by span_id
spans: dict[tuple[str, ...], Span]

#: Only the spans that don't have any parents {client_id: Span}.
#: This is a convenience helper structure to speed up searches.
root_spans: dict[str, Span]

#: All spans, keyed by the individual tags that make up their span_id.
#: This is a convenience helper structure to speed up searches.
spans_search_by_tag: defaultdict[str, set[Span]]

def __init__(self, scheduler: Scheduler):
self.spans = {}
self.root_spans = {}
self.spans_search_by_tag = defaultdict(set)

def ensure_span(self, span_id: tuple[str, ...]) -> Span:
"""Create Span if it doesn't exist and return it"""
try:
return self.spans[span_id]
except KeyError:
pass

span = self.spans[span_id] = Span(span_id)
for tag in span_id:
self.spans_search_by_tag[tag].add(span)
if len(span_id) > 1:
parent = self.ensure_span(span_id[:-1])
parent.children.add(span)
else:
self.root_spans[span_id[0]] = span

return span
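The new ``SpansExtension.ensure_span`` builds the span tree lazily: creating a leaf span also creates all of its ancestors, links parents to children, and indexes every span under each of its tags. A small sketch of how the data structures end up populated; constructing the extension with ``scheduler=None`` is for demonstration only, since ``__init__`` does not use the scheduler argument in this commit.

```python
from distributed.spans import SpansExtension

ext = SpansExtension(scheduler=None)  # type: ignore[arg-type]  # demo only
ext.ensure_span(("Client-abc123", "my_workflow", "phase 1"))
ext.ensure_span(("Client-abc123", "my_workflow", "phase 2"))

root = ext.root_spans["Client-abc123"]           # one root span per client ID
workflow = ext.spans[("Client-abc123", "my_workflow")]
assert workflow in root.children                 # parents link to children
assert len(workflow.children) == 2               # "phase 1" and "phase 2"
# Every span whose id contains the tag is indexed under it
assert len(ext.spans_search_by_tag["my_workflow"]) == 3
```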
49 changes: 26 additions & 23 deletions distributed/tests/test_client.py
@@ -6885,9 +6885,9 @@ async def test_annotations_task_state(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all(
{"qux": "bar", "priority": 100} == ts.annotations for ts in s.tasks.values()
)
for ts in s.tasks.values():
assert ts.annotations["qux"] == "bar"
assert ts.annotations["priority"] == 100


@pytest.mark.parametrize("fn", ["compute", "persist"])
@@ -6903,7 +6903,8 @@ async def test_annotations_compute_time(c, s, a, b, fn):

await wait(fut)
assert s.tasks
assert all(ts.annotations == {"foo": "bar"} for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.annotations["foo"] == "bar"


@pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7036")
@@ -6935,9 +6936,9 @@ async def test_annotations_priorities(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all("15" in str(ts.priority) for ts in s.tasks.values())
assert all(ts.priority[0] == -15 for ts in s.tasks.values())
assert all({"priority": 15} == ts.annotations for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.priority[0] == -15
assert ts.annotations["priority"] == 15


@gen_cluster(client=True)
@@ -6950,8 +6951,10 @@ async def test_annotations_workers(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all({"workers": [a.address]} == ts.annotations for ts in s.tasks.values())
assert all({a.address} == ts.worker_restrictions for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.annotations["workers"] == [a.address]
assert ts.worker_restrictions == {a.address}

assert a.data
assert not b.data

@@ -6966,8 +6969,9 @@ async def test_annotations_retries(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all(ts.retries == 2 for ts in s.tasks.values())
assert all(ts.annotations == {"retries": 2} for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.retries == 2
assert ts.annotations["retries"] == 2


@gen_cluster(client=True)
@@ -7017,8 +7021,9 @@ async def test_annotations_resources(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all([{"GPU": 1} == ts.resource_restrictions for ts in s.tasks.values()])
assert all([{"resources": {"GPU": 1}} == ts.annotations for ts in s.tasks.values()])
for ts in s.tasks.values():
assert ts.resource_restrictions == {"GPU": 1}
assert ts.annotations["resources"] == {"GPU": 1}


@gen_cluster(
@@ -7053,14 +7058,11 @@ async def test_annotations_loose_restrictions(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all(not ts.worker_restrictions for ts in s.tasks.values())
assert all({"fake"} == ts.host_restrictions for ts in s.tasks.values())
assert all(
[
{"workers": ["fake"], "allow_other_workers": True} == ts.annotations
for ts in s.tasks.values()
]
)
for ts in s.tasks.values():
assert not ts.worker_restrictions
assert ts.host_restrictions == {"fake"}
assert ts.annotations["workers"] == ["fake"]
assert ts.annotations["allow_other_workers"] is True


@gen_cluster(
@@ -7078,8 +7080,9 @@ async def test_annotations_submit_map(c, s, a, b):

await wait([f, *fs])

assert all([{"foo": 1} == ts.resource_restrictions for ts in s.tasks.values()])
assert all([{"resources": {"foo": 1}} == ts.annotations for ts in s.tasks.values()])
for ts in s.tasks.values():
assert ts.resource_restrictions == {"foo": 1}
assert ts.annotations["resources"] == {"foo": 1}
assert not b.state.tasks


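The test rewrites above, from exact-equality checks on ``ts.annotations`` to per-key assertions, are needed because every task now also carries the automatic ``span`` annotation, so comparing against the originally expected dict would fail. A sketch with illustrative values:

```python
# ts.annotations after this commit (values are illustrative)
annotations = {"retries": 2, "span": ("Client-abc123",)}

assert annotations != {"retries": 2}   # old exact-equality assertion breaks
assert annotations["retries"] == 2     # new per-key assertion still passes
```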