Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in Fine Performance Metrics sync #7927

Merged
merged 1 commit on Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions distributed/spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
import weakref
from collections import defaultdict
from collections.abc import Iterable, Iterator
from collections.abc import Hashable, Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -124,7 +124,7 @@ class Span:
#: stop
enqueued: float

_cumulative_worker_metrics: defaultdict[tuple[str, ...], float]
_cumulative_worker_metrics: defaultdict[tuple[Hashable, ...], float]

# Support for weakrefs to a class with __slots__
__weakref__: Any
Expand Down Expand Up @@ -266,7 +266,7 @@ def nbytes_total(self) -> int:
return sum(tg.nbytes_total for tg in self.traverse_groups())

@property
def cumulative_worker_metrics(self) -> defaultdict[tuple[str, ...], float]:
def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]:
"""Replica of Worker.digests_total and Scheduler.cumulative_worker_metrics, but
only for the metrics that can be attributed to the current span tree.
The span_id has been removed from the key.
Expand All @@ -276,7 +276,7 @@ def cumulative_worker_metrics(self) -> defaultdict[tuple[str, ...], float]:
but more may be added in the future with a different format; please test for
``k[0] == "execute"``.
"""
out: defaultdict[tuple[str, ...], float] = defaultdict(float)
out: defaultdict[tuple[Hashable, ...], float] = defaultdict(float)
for child in self.traverse_spans():
for k, v in child._cumulative_worker_metrics.items():
out[k] += v
Expand Down Expand Up @@ -418,7 +418,7 @@ def merge_by_tags(self, *tags: str) -> Span:
return Span.merge(*self.find_by_tags(*tags))

def heartbeat(
self, ws: WorkerState, data: dict[str, dict[tuple[str, ...], float]]
self, ws: WorkerState, data: dict[tuple[Hashable, ...], float]
) -> None:
"""Triggered by SpansWorkerExtension.heartbeat().

Expand All @@ -429,36 +429,49 @@ def heartbeat(
SpansWorkerExtension.heartbeat
Span.cumulative_worker_metrics
"""
for span_id, metrics in data.items():
for (context, span_id, *other), v in data.items():
assert isinstance(span_id, str)
span = self.spans[span_id]
for k, v in metrics.items():
span._cumulative_worker_metrics[k] += v
span._cumulative_worker_metrics[(context, *other)] += v


class SpansWorkerExtension:
    """Worker extension for spans support"""

    # Owning worker; source of the raw metrics digests
    worker: Worker
    # Snapshot of span-related metrics, populated by collect_digests() and
    # drained (returned and reset) by heartbeat()
    digests_total_since_heartbeat: dict[tuple[Hashable, ...], float]

    def __init__(self, worker: Worker):
        self.worker = worker
        self.digests_total_since_heartbeat = {}

    def collect_digests(self) -> None:
        """Make a local copy of Worker.digests_total_since_heartbeat. We can't just
        parse it directly in heartbeat() as the event loop may be yielded between its
        call and `self.worker.digests_total_since_heartbeat.clear()`, causing the
        scheduler to become misaligned with the workers.
        """
        # Note: this method may be called spuriously by Worker._register_with_scheduler,
        # but when it does it's guaranteed not to find any metrics
        assert not self.digests_total_since_heartbeat
        # Keep only the per-span "execute" metrics; other digests are handled
        # by the plain (span-less) metrics path in Worker.get_metrics
        self.digests_total_since_heartbeat = {
            k: v
            for k, v in self.worker.digests_total_since_heartbeat.items()
            if isinstance(k, tuple) and k[0] == "execute"
        }

    def heartbeat(self) -> dict[tuple[Hashable, ...], float]:
        """Apportion the metrics that do have a span to the Spans on the scheduler

        Returns
        -------
        ``{(context, span_id, prefix, activity, unit): value}``

        See also
        --------
        SpansSchedulerExtension.heartbeat
        Span.cumulative_worker_metrics
        """
        # Hand off the snapshot taken by collect_digests() and reset it, so the
        # next collect_digests() starts from a clean slate
        out = self.digests_total_since_heartbeat
        self.digests_total_since_heartbeat = {}
        return out
22 changes: 22 additions & 0 deletions distributed/tests/test_worker_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,25 @@ async def test_no_spans_extension(c, s, a):
if not WINDOWS:
assert w_metrics[wk] > 0
assert s_metrics[sk] == w_metrics[wk]


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_new_metrics_during_heartbeat(c, s, a):
    """Make sure that metrics generated during the heartbeat don't get lost"""
    # Create default span
    await c.submit(inc, 1)
    span = s.extensions["spans"].spans_search_by_name["default",][0]

    # Start a heartbeat and, while it is in flight, inject one new metric on
    # every event-loop iteration. Before the race-condition fix, metrics
    # produced in this window were wiped by digests_total_since_heartbeat.clear().
    hb_task = asyncio.create_task(a.heartbeat())
    n = 0
    while not hb_task.done():
        n += 1
        a.digest_metric(("execute", span.id, "x", "test", "test"), 1)
        await asyncio.sleep(0)
    await hb_task
    # Sanity check: the loop must actually have raced with the heartbeat
    assert n > 10
    # Second heartbeat flushes the metrics injected during the first one
    await a.heartbeat()

    # Every injected metric must be accounted for on worker, scheduler, and span;
    # span/scheduler keys have span_id squashed out of the worker-side key
    assert a.digests_total["execute", span.id, "x", "test", "test"] == n
    assert s.cumulative_worker_metrics["execute", "x", "test", "test"] == n
    assert span.cumulative_worker_metrics["execute", "x", "test", "test"] == n
15 changes: 12 additions & 3 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,14 +1034,24 @@ async def get_metrics(self) -> dict:
# spilling is disabled
spilled_memory, spilled_disk = 0, 0

# Squash span_id in metrics.
# SpansWorkerExtension, if loaded, will send them out disaggregated.
# Send Fine Performance Metrics
# Make sure we do not yield the event loop between the moment we parse
# self.digests_total_since_heartbeat to send it to the scheduler and the moment
# we clear it!
spans_ext: SpansWorkerExtension | None = self.extensions.get("spans")
if spans_ext:
# Send metrics with disaggregated span_id
spans_ext.collect_digests()

# Send metrics with squashed span_id
digests: defaultdict[Hashable, float] = defaultdict(float)
for k, v in self.digests_total_since_heartbeat.items():
if isinstance(k, tuple) and k[0] == "execute":
k = k[:1] + k[2:]
digests[k] += v

self.digests_total_since_heartbeat.clear()

out = dict(
task_counts=self.state.task_counter.current_count(by_prefix=False),
bandwidth={
Expand Down Expand Up @@ -1259,7 +1269,6 @@ async def heartbeat(self) -> None:
if hasattr(extension, "heartbeat")
},
)
self.digests_total_since_heartbeat.clear()

end = time()
middle = (start + end) / 2
Expand Down