Skip to content

Commit

Permalink
spans
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 30, 2023
1 parent 3286d2f commit 33f0f1e
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 24 deletions.
10 changes: 9 additions & 1 deletion distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.shuffle import ShuffleSchedulerExtension
from distributed.spans import SpansExtension
from distributed.stealing import WorkStealing
from distributed.utils import (
All,
Expand Down Expand Up @@ -170,6 +171,7 @@
"amm": ActiveMemoryManagerExtension,
"memory_sampler": MemorySamplerExtension,
"shuffle": ShuffleSchedulerExtension,
"spans": SpansExtension,
"stealing": WorkStealing,
}

Expand Down Expand Up @@ -4410,6 +4412,12 @@ def update_graph(
recommendations[ts.key] = "erred"
break

spans_ext: SpansExtension | None = self.extensions.get("spans")
if spans_ext:
span_annotations = spans_ext.new_tasks(new_tasks)
if span_annotations:
resolved_annotations["span"] = span_annotations

for plugin in list(self.plugins.values()):
try:
plugin.update_graph(
Expand Down Expand Up @@ -4514,7 +4522,7 @@ def _parse_and_apply_annotations(
...
}
"""
resolved_annotations: dict[str, dict[str, Any]] = defaultdict(dict)
resolved_annotations: defaultdict[str, dict[str, Any]] = defaultdict(dict)
for ts in tasks:
key = ts.key
# This could be a typed dict
Expand Down
137 changes: 137 additions & 0 deletions distributed/spans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

import dask.config

if TYPE_CHECKING:
from distributed import Scheduler
from distributed.scheduler import TaskState


@contextmanager
def span(*tags: str) -> Iterator[None]:
"""Tag group of tasks to be part of a certain group, called a span.
This context manager can be nested, thus creating sub-spans.
Every cluster defines a "default" span when no span has been defined by the client.
Examples
--------
>>> import dask.array as da
>>> import distributed
>>> client = distributed.Client()
>>> with span("my_workflow"):
... with span("phase 1"):
... a = da.random.random(10)
... b = a + 1
... with span("phase 2"):
... c = b * 2
... d = c.sum()
>>> d.compute()
In the above example,
- Tasks of collections a and b will be annotated on the scheduler and workers with
``{'span': ('my_workflow', 'phase 1')}``
- Tasks of collection c (that aren't already part of a or b) will be annotated with
``{'span': ('my_workflow', 'phase 2')}``
- Tasks of collection d (that aren't already part of a, b, or c) will *not* be
annotated but will nonetheless be attached to span ``('default', )``
You may also set more than one tag at once; e.g.
>>> with span("workflow1", "version1"):
... ...
Note
----
Spans are based on annotations, and just like annotations they can be lost during
optimization. Set config ``optimizatione.fuse.active: false`` to prevent this issue.
"""
prev_id = dask.config.get("annotations.span", ())
with dask.config.set({"annotations.span": prev_id + tags}):
yield


class Span:
#: (<tag>, <tag>, ...)
#: Matches ``TaskState.annotations["span"]``, both on the scheduler and the worker,
#: as well as ``TaskGroup.span``.
#: Tasks with no 'span' annotation will be attached to Span ``("default", )``.
id: tuple[str, ...]

#: Direct children of this span tree
#: Note: you can get the parent through
#: ``distributed.extensions["spans"].spans[self.id[:-1]]``
children: set[Span]

__slots__ = tuple(__annotations__)

def __init__(self, span_id: tuple[str, ...]):
self.id = span_id
self.children = set()

def __repr__(self) -> str:
return f"Span{self.id}"


class SpansExtension:
"""Scheduler extension for spans support"""

#: All Span objects by span_id
spans: dict[tuple[str, ...], Span]

#: Only the spans that don't have any parents {client_id: Span}.
#: This is a convenience helper structure to speed up searches.
root_spans: dict[str, Span]

#: All spans, keyed by the individual tags that make up their span_id.
#: This is a convenience helper structure to speed up searches.
spans_search_by_tag: defaultdict[str, set[Span]]

def __init__(self, scheduler: Scheduler):
self.spans = {}
self.root_spans = {}
self.spans_search_by_tag = defaultdict(set)

def new_tasks(self, tss: Iterable[TaskState]) -> dict[str, tuple[str, ...]]:
"""Acknowledge the creation of new tasks on the scheduler.
Attach tasks to either the desired span or to ("default", ).
Update TaskState.annotations["span"].
Returns
-------
{task key: span id}, only for tasks that explicitly define a span
"""
out = {}
for ts in tss:
span_id = ts.annotations.get("span", ())
assert isinstance(span_id, tuple)
if span_id:
ts.annotations["span"] = out[ts.key] = span_id
else:
span_id = ("default",)
self._ensure_span(span_id)

return out

def _ensure_span(self, span_id: tuple[str, ...]) -> Span:
"""Create Span if it doesn't exist and return it"""
try:
return self.spans[span_id]
except KeyError:
pass

span = self.spans[span_id] = Span(span_id)
for tag in span_id:
self.spans_search_by_tag[tag].add(span)
if len(span_id) > 1:
parent = self._ensure_span(span_id[:-1])
parent.children.add(span)
else:
self.root_spans[span_id[0]] = span

return span
49 changes: 26 additions & 23 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6886,9 +6886,9 @@ async def test_annotations_task_state(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all(
{"qux": "bar", "priority": 100} == ts.annotations for ts in s.tasks.values()
)
for ts in s.tasks.values():
assert ts.annotations["qux"] == "bar"
assert ts.annotations["priority"] == 100


@pytest.mark.parametrize("fn", ["compute", "persist"])
Expand All @@ -6904,7 +6904,8 @@ async def test_annotations_compute_time(c, s, a, b, fn):

await wait(fut)
assert s.tasks
assert all(ts.annotations == {"foo": "bar"} for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.annotations["foo"] == "bar"


@pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7036")
Expand Down Expand Up @@ -6936,9 +6937,9 @@ async def test_annotations_priorities(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all("15" in str(ts.priority) for ts in s.tasks.values())
assert all(ts.priority[0] == -15 for ts in s.tasks.values())
assert all({"priority": 15} == ts.annotations for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.priority[0] == -15
assert ts.annotations["priority"] == 15


@gen_cluster(client=True)
Expand All @@ -6951,8 +6952,10 @@ async def test_annotations_workers(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all({"workers": [a.address]} == ts.annotations for ts in s.tasks.values())
assert all({a.address} == ts.worker_restrictions for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.annotations["workers"] == [a.address]
assert ts.worker_restrictions == {a.address}

assert a.data
assert not b.data

Expand All @@ -6967,8 +6970,9 @@ async def test_annotations_retries(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all(ts.retries == 2 for ts in s.tasks.values())
assert all(ts.annotations == {"retries": 2} for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.retries == 2
assert ts.annotations["retries"] == 2


@gen_cluster(client=True)
Expand Down Expand Up @@ -7018,8 +7022,9 @@ async def test_annotations_resources(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all([{"GPU": 1} == ts.resource_restrictions for ts in s.tasks.values()])
assert all([{"resources": {"GPU": 1}} == ts.annotations for ts in s.tasks.values()])
for ts in s.tasks.values():
assert ts.resource_restrictions == {"GPU": 1}
assert ts.annotations["resources"] == {"GPU": 1}


@gen_cluster(
Expand Down Expand Up @@ -7054,14 +7059,11 @@ async def test_annotations_loose_restrictions(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all(not ts.worker_restrictions for ts in s.tasks.values())
assert all({"fake"} == ts.host_restrictions for ts in s.tasks.values())
assert all(
[
{"workers": ["fake"], "allow_other_workers": True} == ts.annotations
for ts in s.tasks.values()
]
)
for ts in s.tasks.values():
assert not ts.worker_restrictions
assert ts.host_restrictions == {"fake"}
assert ts.annotations["workers"] == ["fake"]
assert ts.annotations["allow_other_workers"] is True


@gen_cluster(
Expand All @@ -7079,8 +7081,9 @@ async def test_annotations_submit_map(c, s, a, b):

await wait([f, *fs])

assert all([{"foo": 1} == ts.resource_restrictions for ts in s.tasks.values()])
assert all([{"resources": {"foo": 1}} == ts.annotations for ts in s.tasks.values()])
for ts in s.tasks.values():
assert ts.resource_restrictions == {"foo": 1}
assert ts.annotations["resources"] == {"foo": 1}
assert not b.state.tasks


Expand Down
101 changes: 101 additions & 0 deletions distributed/tests/test_spans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations

from dask import delayed

from distributed.spans import span
from distributed.utils_test import async_poll_for, gen_cluster, inc


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_spans(c, s, a):
x = delayed(inc)(1) # Default span
with span("my workflow"):
with span("p1"):
y = x + 1

@span("p2")
def f(i):
return i * 2

z = f(y)

zp = c.persist(z)
assert await c.compute(zp) == 6

ext = s.extensions["spans"]

assert s.tasks[x.key].annotations == {}
assert s.tasks[y.key].annotations == {"span": ("my workflow", "p1")}
assert s.tasks[z.key].annotations == {"span": ("my workflow", "p2")}

assert a.state.tasks[x.key].annotations == {}
assert a.state.tasks[y.key].annotations == {"span": ("my workflow", "p1")}
assert a.state.tasks[z.key].annotations == {"span": ("my workflow", "p2")}

assert ext.spans.keys() == {
("default",),
("my workflow",),
("my workflow", "p1"),
("my workflow", "p2"),
}
for k, sp in ext.spans.items():
assert sp.id == k

default = ext.spans["default",]
mywf = ext.spans["my workflow",]
p1 = ext.spans["my workflow", "p1"]
p2 = ext.spans["my workflow", "p2"]

assert default.children == set()
assert mywf.children == {p1, p2}
assert p1.children == set()
assert p2.children == set()

assert str(default) == "Span('default',)"
assert str(p1) == "Span('my workflow', 'p1')"
assert ext.root_spans == {"default": default, "my workflow": mywf}
assert ext.spans_search_by_tag["my workflow"] == {mywf, p1, p2}

assert s.tasks[x.key].annotations == {}
assert s.tasks[y.key].annotations["span"] == ("my workflow", "p1")

# Test that spans survive their tasks
del zp
await async_poll_for(lambda: not s.tasks, timeout=5)
assert ext.spans.keys() == {
("default",),
("my workflow",),
("my workflow", "p1"),
("my workflow", "p2"),
}


@gen_cluster(client=True)
async def test_submit(c, s, a, b):
x = c.submit(inc, 1, key="x")
with span("foo"):
y = c.submit(inc, 2, key="y")
assert await x == 2
assert await y == 3

assert "span" not in s.tasks["x"].annotations
assert s.tasks["y"].annotations["span"] == ("foo",)
assert s.extensions["spans"].spans.keys() == {("default",), ("foo",)}


@gen_cluster(client=True)
async def test_multiple_tags(c, s, a, b):
with span("foo", "bar"):
x = c.submit(inc, 1, key="x")
assert await x == 2

assert s.tasks["x"].annotations["span"] == ("foo", "bar")
assert s.extensions["spans"].spans_search_by_tag.keys() == {"foo", "bar"}


@gen_cluster(client=True, scheduler_kwargs={"extensions": {}})
async def test_no_extension(c, s, a, b):
x = c.submit(inc, 1, key="x")
assert await x == 2
assert "spans" not in s.extensions
assert s.tasks["x"].annotations == {}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ allow_incomplete_defs = true
# Recent or recently overhauled modules featuring stricter validation
module = [
"distributed.active_memory_manager",
"distributed.spans",
"distributed.system_monitor",
"distributed.worker_memory",
"distributed.worker_state_machine",
Expand Down

0 comments on commit 33f0f1e

Please sign in to comment.