Spans skeleton #7862

Merged · 2 commits · May 31, 2023
distributed/scheduler.py: 10 changes (9 additions, 1 deletion)

@@ -99,6 +99,7 @@
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.shuffle import ShuffleSchedulerExtension
from distributed.spans import SpansExtension
from distributed.stealing import WorkStealing
from distributed.utils import (
All,
@@ -170,6 +171,7 @@
"amm": ActiveMemoryManagerExtension,
"memory_sampler": MemorySamplerExtension,
"shuffle": ShuffleSchedulerExtension,
"spans": SpansExtension,
"stealing": WorkStealing,
}
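
Since "spans" is added to the scheduler's default extensions mapping, the
extension is active on every scheduler unless explicitly opted out, which is
what test_no_extension (below) does via an empty extensions dict. A minimal
sketch, assuming the usual Scheduler constructor:

import asyncio
from distributed import Scheduler

async def main() -> None:
    # extensions={} suppresses all default extensions,
    # so no SpansExtension is instantiated
    async with Scheduler(dashboard_address=":0", extensions={}) as s:
        assert "spans" not in s.extensions

asyncio.run(main())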

@@ -4410,6 +4412,12 @@ def update_graph(
recommendations[ts.key] = "erred"
break

spans_ext: SpansExtension | None = self.extensions.get("spans")
if spans_ext:
span_annotations = spans_ext.new_tasks(new_tasks)
if span_annotations:
resolved_annotations["span"] = span_annotations

for plugin in list(self.plugins.values()):
try:
plugin.update_graph(
@@ -4514,7 +4522,7 @@ def _parse_and_apply_annotations(
...
}
"""
-resolved_annotations: dict[str, dict[str, Any]] = defaultdict(dict)
+resolved_annotations: defaultdict[str, dict[str, Any]] = defaultdict(dict)
for ts in tasks:
key = ts.key
# This could be a typed dict
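
In the update_graph hook above, SpansExtension.new_tasks returns
{task key: span id} for tasks that carry an explicit span annotation, and
update_graph stores that mapping into resolved_annotations under the "span"
key so it travels to the workers alongside the other annotations. A sketch of
the resulting shape, with hypothetical task keys:

# Hypothetical illustration: two tasks submitted inside span("wf", "p1");
# a task submitted outside any span does not appear in the mapping.
span_annotations = {"add-1": ("wf", "p1"), "sum-2": ("wf", "p1")}
resolved_annotations = {"span": span_annotations}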
distributed/spans.py: 137 additions (new file)

@@ -0,0 +1,137 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

import dask.config

if TYPE_CHECKING:
from distributed import Scheduler
from distributed.scheduler import TaskState


@contextmanager
def span(*tags: str) -> Iterator[None]:
"""Tag group of tasks to be part of a certain group, called a span.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring is not rendered anywhere at the moment.
We should wait until we have something actually useful before we advertise it in sphinx.

This context manager can be nested, thus creating sub-spans.
Every cluster defines a global "default" span when no span has been defined by the client.
Examples
--------
>>> import dask.array as da
>>> import distributed
>>> client = distributed.Client()
>>> with span("my_workflow"):
... with span("phase 1"):
... a = da.random.random(10)
... b = a + 1
... with span("phase 2"):
... c = b * 2
... d = c.sum()
>>> d.compute()

In the above example:

- Tasks of collections a and b will be annotated on the scheduler and workers
  with ``{'span': ('my_workflow', 'phase 1')}``
- Tasks of collection c (that aren't already part of a or b) will be annotated
  with ``{'span': ('my_workflow', 'phase 2')}``
- Tasks of collection d (that aren't already part of a, b, or c) will *not* be
  annotated but will nonetheless be attached to span ``('default', )``

You may also set more than one tag at once; e.g.

>>> with span("workflow1", "version1"):
...     ...

Note
----
Spans are based on annotations, and just like annotations they can be lost
during optimization. Set config ``optimization.fuse.active: false`` to
prevent this issue.
"""
prev_id = dask.config.get("annotations.span", ())
with dask.config.set({"annotations.span": prev_id + tags}):
yield
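
A minimal sketch (not part of this diff) of the mechanism: nesting works by
concatenating the new tags onto whatever ``annotations.span`` tuple is
already set in dask's config, which dask in turn attaches to tasks created
inside the block.

import dask
from distributed.spans import span

with span("my_workflow"):
    # span() extends the "annotations.span" config key with its tags
    assert dask.config.get("annotations.span") == ("my_workflow",)
    with span("phase 1"):
        # nested spans concatenate tags, forming the sub-span id
        assert dask.config.get("annotations.span") == ("my_workflow", "phase 1")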


class Span:
#: (<tag>, <tag>, ...)
#: Matches ``TaskState.annotations["span"]``, both on the scheduler and the worker,
#: as well as ``TaskGroup.span``.
#: Tasks with no 'span' annotation will be attached to Span ``("default", )``.
id: tuple[str, ...]

#: Direct children of this span tree
#: Note: you can get the parent through
#: ``distributed.extensions["spans"].spans[self.id[:-1]]``
children: set[Span]

__slots__ = tuple(__annotations__)

def __init__(self, span_id: tuple[str, ...]):
self.id = span_id
self.children = set()

def __repr__(self) -> str:
return f"Span{self.id}"


class SpansExtension:
"""Scheduler extension for spans support"""

#: All Span objects by span_id
spans: dict[tuple[str, ...], Span]

#: Only the spans that don't have any parents {first tag: Span}.
#: This is a convenience helper structure to speed up searches.
root_spans: dict[str, Span]

#: All spans, keyed by the individual tags that make up their span_id.
#: This is a convenience helper structure to speed up searches.
spans_search_by_tag: defaultdict[str, set[Span]]

def __init__(self, scheduler: Scheduler):
self.spans = {}
self.root_spans = {}
self.spans_search_by_tag = defaultdict(set)

def new_tasks(self, tss: Iterable[TaskState]) -> dict[str, tuple[str, ...]]:
"""Acknowledge the creation of new tasks on the scheduler.
Attach tasks to either the desired span or to ("default", ).
Update TaskState.annotations["span"].
Returns
-------
{task key: span id}, only for tasks that explicitly define a span
"""
out = {}
for ts in tss:
span_id = ts.annotations.get("span", ())
assert isinstance(span_id, tuple)
if span_id:
ts.annotations["span"] = out[ts.key] = span_id
else:
span_id = ("default",)
self._ensure_span(span_id)

return out

def _ensure_span(self, span_id: tuple[str, ...]) -> Span:
"""Create Span if it doesn't exist and return it"""
try:
return self.spans[span_id]
except KeyError:
pass

span = self.spans[span_id] = Span(span_id)
for tag in span_id:
self.spans_search_by_tag[tag].add(span)
if len(span_id) > 1:
parent = self._ensure_span(span_id[:-1])
parent.children.add(span)
else:
self.root_spans[span_id[0]] = span

return span
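
_ensure_span recursively materialises every ancestor, so registering one deep
span id populates ``spans``, ``root_spans`` and ``spans_search_by_tag`` in a
single call. A minimal sketch (not part of this diff), calling the private
helper directly; since __init__ only initialises the dicts, passing
scheduler=None is enough for illustration:

from distributed.spans import SpansExtension

ext = SpansExtension(scheduler=None)  # type: ignore[arg-type]
leaf = ext._ensure_span(("wf", "phase 1", "step a"))
# all ancestors were created and linked
assert ext.spans.keys() == {("wf",), ("wf", "phase 1"), ("wf", "phase 1", "step a")}
assert ext.root_spans.keys() == {"wf"}
# every tag in the id indexes the new span
assert leaf in ext.spans_search_by_tag["step a"]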
distributed/tests/test_client.py: 49 changes (26 additions, 23 deletions)

@@ -6886,9 -6886,9 @@ async def test_annotations_task_state(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

-assert all(
-    {"qux": "bar", "priority": 100} == ts.annotations for ts in s.tasks.values()
-)
+for ts in s.tasks.values():
+    assert ts.annotations["qux"] == "bar"
+    assert ts.annotations["priority"] == 100
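
Throughout this file, whole-dict equality checks inside ``assert all(...)``
are replaced with per-task, per-key asserts. Plausibly this both
future-proofs the tests (exact equality breaks as soon as another extension,
such as spans, adds an annotation key) and improves failure messages (a
failing generator inside all() does not say which task or key broke). A toy
illustration with an invented extra key:

# Toy example (not from the diff): an extra annotation key breaks exact
# equality but not per-key asserts.
annotations = {"qux": "bar", "priority": 100, "span": ("wf",)}
assert annotations != {"qux": "bar", "priority": 100}  # old-style check would fail
assert annotations["qux"] == "bar"                      # new-style checks still pass
assert annotations["priority"] == 100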


@pytest.mark.parametrize("fn", ["compute", "persist"])
@@ -6904,7 +6904,8 @@ async def test_annotations_compute_time(c, s, a, b, fn):

await wait(fut)
assert s.tasks
-assert all(ts.annotations == {"foo": "bar"} for ts in s.tasks.values())
+for ts in s.tasks.values():
+    assert ts.annotations["foo"] == "bar"


@pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7036")
@@ -6936,9 +6937,9 @@ async def test_annotations_priorities(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all("15" in str(ts.priority) for ts in s.tasks.values())
assert all(ts.priority[0] == -15 for ts in s.tasks.values())
assert all({"priority": 15} == ts.annotations for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.priority[0] == -15
assert ts.annotations["priority"] == 15


@gen_cluster(client=True)
@@ -6951,8 +6952,10 @@ async def test_annotations_workers(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all({"workers": [a.address]} == ts.annotations for ts in s.tasks.values())
assert all({a.address} == ts.worker_restrictions for ts in s.tasks.values())
for ts in s.tasks.values():
assert ts.annotations["workers"] == [a.address]
assert ts.worker_restrictions == {a.address}

assert a.data
assert not b.data

@@ -6967,8 +6970,9 @@ async def test_annotations_retries(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

-assert all(ts.retries == 2 for ts in s.tasks.values())
-assert all(ts.annotations == {"retries": 2} for ts in s.tasks.values())
+for ts in s.tasks.values():
+    assert ts.retries == 2
+    assert ts.annotations["retries"] == 2


@gen_cluster(client=True)
@@ -7018,8 +7022,9 @@ async def test_annotations_resources(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

assert all([{"GPU": 1} == ts.resource_restrictions for ts in s.tasks.values()])
assert all([{"resources": {"GPU": 1}} == ts.annotations for ts in s.tasks.values()])
for ts in s.tasks.values():
assert ts.resource_restrictions == {"GPU": 1}
assert ts.annotations["resources"] == {"GPU": 1}


@gen_cluster(
@@ -7054,14 +7059,11 @@ async def test_annotations_loose_restrictions(c, s, a, b):
with dask.config.set(optimization__fuse__active=False):
x = await x.persist()

-assert all(not ts.worker_restrictions for ts in s.tasks.values())
-assert all({"fake"} == ts.host_restrictions for ts in s.tasks.values())
-assert all(
-    [
-        {"workers": ["fake"], "allow_other_workers": True} == ts.annotations
-        for ts in s.tasks.values()
-    ]
-)
+for ts in s.tasks.values():
+    assert not ts.worker_restrictions
+    assert ts.host_restrictions == {"fake"}
+    assert ts.annotations["workers"] == ["fake"]
+    assert ts.annotations["allow_other_workers"] is True


@gen_cluster(
@@ -7079,8 +7081,9 @@ async def test_annotations_submit_map(c, s, a, b):

await wait([f, *fs])

assert all([{"foo": 1} == ts.resource_restrictions for ts in s.tasks.values()])
assert all([{"resources": {"foo": 1}} == ts.annotations for ts in s.tasks.values()])
for ts in s.tasks.values():
assert ts.resource_restrictions == {"foo": 1}
assert ts.annotations["resources"] == {"foo": 1}
assert not b.state.tasks


distributed/tests/test_spans.py: 101 additions (new file)

@@ -0,0 +1,101 @@
from __future__ import annotations

from dask import delayed

from distributed.spans import span
from distributed.utils_test import async_poll_for, gen_cluster, inc


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_spans(c, s, a):
x = delayed(inc)(1) # Default span
with span("my workflow"):
with span("p1"):
y = x + 1

@span("p2")
def f(i):
return i * 2

z = f(y)

zp = c.persist(z)
assert await c.compute(zp) == 6

ext = s.extensions["spans"]

assert s.tasks[x.key].annotations == {}
assert s.tasks[y.key].annotations == {"span": ("my workflow", "p1")}
assert s.tasks[z.key].annotations == {"span": ("my workflow", "p2")}

assert a.state.tasks[x.key].annotations == {}
assert a.state.tasks[y.key].annotations == {"span": ("my workflow", "p1")}
assert a.state.tasks[z.key].annotations == {"span": ("my workflow", "p2")}

assert ext.spans.keys() == {
("default",),
("my workflow",),
("my workflow", "p1"),
("my workflow", "p2"),
}
for k, sp in ext.spans.items():
assert sp.id == k

default = ext.spans["default",]
mywf = ext.spans["my workflow",]
p1 = ext.spans["my workflow", "p1"]
p2 = ext.spans["my workflow", "p2"]

assert default.children == set()
assert mywf.children == {p1, p2}
assert p1.children == set()
assert p2.children == set()

assert str(default) == "Span('default',)"
assert str(p1) == "Span('my workflow', 'p1')"
assert ext.root_spans == {"default": default, "my workflow": mywf}
assert ext.spans_search_by_tag["my workflow"] == {mywf, p1, p2}

assert s.tasks[x.key].annotations == {}
assert s.tasks[y.key].annotations["span"] == ("my workflow", "p1")

# Test that spans survive their tasks
del zp
await async_poll_for(lambda: not s.tasks, timeout=5)
assert ext.spans.keys() == {
("default",),
("my workflow",),
("my workflow", "p1"),
("my workflow", "p2"),
}


@gen_cluster(client=True)
async def test_submit(c, s, a, b):
x = c.submit(inc, 1, key="x")
with span("foo"):
y = c.submit(inc, 2, key="y")
assert await x == 2
assert await y == 3

assert "span" not in s.tasks["x"].annotations
assert s.tasks["y"].annotations["span"] == ("foo",)
assert s.extensions["spans"].spans.keys() == {("default",), ("foo",)}


@gen_cluster(client=True)
async def test_multiple_tags(c, s, a, b):
with span("foo", "bar"):
x = c.submit(inc, 1, key="x")
assert await x == 2

assert s.tasks["x"].annotations["span"] == ("foo", "bar")
assert s.extensions["spans"].spans_search_by_tag.keys() == {"foo", "bar"}


@gen_cluster(client=True, scheduler_kwargs={"extensions": {}})
async def test_no_extension(c, s, a, b):
x = c.submit(inc, 1, key="x")
assert await x == 2
assert "spans" not in s.extensions
assert s.tasks["x"].annotations == {}
pyproject.toml: 1 change (1 addition)

@@ -186,6 +186,7 @@ allow_incomplete_defs = true
# Recent or recently overhauled modules featuring stricter validation
module = [
"distributed.active_memory_manager",
"distributed.spans",
"distributed.system_monitor",
"distributed.worker_memory",
"distributed.worker_state_machine",