From 49dc6be705396c575522445ba963b2c81fff6404 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 16 Jun 2023 14:17:14 +0100 Subject: [PATCH] Fix race condition in Fine Performance Metrics sync --- distributed/spans.py | 47 +++++++++++++++--------- distributed/tests/test_worker_metrics.py | 22 +++++++++++ distributed/worker.py | 15 ++++++-- 3 files changed, 64 insertions(+), 20 deletions(-) diff --git a/distributed/spans.py b/distributed/spans.py index 56757b015a9..49a8e7f5b2a 100644 --- a/distributed/spans.py +++ b/distributed/spans.py @@ -3,7 +3,7 @@ import uuid import weakref from collections import defaultdict -from collections.abc import Iterable, Iterator +from collections.abc import Hashable, Iterable, Iterator from contextlib import contextmanager from typing import TYPE_CHECKING, Any @@ -124,7 +124,7 @@ class Span: #: stop enqueued: float - _cumulative_worker_metrics: defaultdict[tuple[str, ...], float] + _cumulative_worker_metrics: defaultdict[tuple[Hashable, ...], float] # Support for weakrefs to a class with __slots__ __weakref__: Any @@ -266,7 +266,7 @@ def nbytes_total(self) -> int: return sum(tg.nbytes_total for tg in self.traverse_groups()) @property - def cumulative_worker_metrics(self) -> defaultdict[tuple[str, ...], float]: + def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]: """Replica of Worker.digests_total and Scheduler.cumulative_worker_metrics, but only for the metrics that can be attributed to the current span tree. The span_id has been removed from the key. @@ -276,7 +276,7 @@ def cumulative_worker_metrics(self) -> defaultdict[tuple[str, ...], float]: but more may be added in the future with a different format; please test for ``k[0] == "execute"``. """ - out: defaultdict[tuple[str, ...], float] = defaultdict(float) + out: defaultdict[tuple[Hashable, ...], float] = defaultdict(float) for child in self.traverse_spans(): for k, v in child._cumulative_worker_metrics.items(): out[k] += v @@ -418,7 +418,7 @@ def merge_by_tags(self, *tags: str) -> Span: return Span.merge(*self.find_by_tags(*tags)) def heartbeat( - self, ws: WorkerState, data: dict[str, dict[tuple[str, ...], float]] + self, ws: WorkerState, data: dict[tuple[Hashable, ...], float] ) -> None: """Triggered by SpansWorkerExtension.heartbeat(). @@ -429,36 +429,49 @@ def heartbeat( SpansWorkerExtension.heartbeat Span.cumulative_worker_metrics """ - for span_id, metrics in data.items(): + for (context, span_id, *other), v in data.items(): + assert isinstance(span_id, str) span = self.spans[span_id] - for k, v in metrics.items(): - span._cumulative_worker_metrics[k] += v + span._cumulative_worker_metrics[(context, *other)] += v class SpansWorkerExtension: """Worker extension for spans support""" worker: Worker + digests_total_since_heartbeat: dict[tuple[Hashable, ...], float] def __init__(self, worker: Worker): self.worker = worker + self.digests_total_since_heartbeat = {} - def heartbeat(self) -> dict[str, dict[tuple[str, ...], float]]: + def collect_digests(self) -> None: + """Make a local copy of Worker.digests_total_since_heartbeat. We can't just + parse it directly in heartbeat() as the event loop may be yielded between its + call and `self.worker.digests_total_since_heartbeat.clear()`, causing the + scheduler to become misaligned with the workers. + """ + # Note: this method may be called spuriously by Worker._register_with_scheduler, + # but when it does it's guaranteed not to find any metrics + assert not self.digests_total_since_heartbeat + self.digests_total_since_heartbeat = { + k: v + for k, v in self.worker.digests_total_since_heartbeat.items() + if isinstance(k, tuple) and k[0] == "execute" + } + + def heartbeat(self) -> dict[tuple[Hashable, ...], float]: """Apportion the metrics that do have a span to the Spans on the scheduler Returns ------- - ``{span_id: {("execute", prefix, activity, unit): value}}`` + ``{(context, span_id, prefix, activity, unit): value}}`` See also -------- SpansSchedulerExtension.heartbeat Span.cumulative_worker_metrics """ - out: defaultdict[str, dict[tuple[str, ...], float]] = defaultdict(dict) - for k, v in self.worker.digests_total_since_heartbeat.items(): - if isinstance(k, tuple) and k[0] == "execute": - _, span_id, prefix, activity, unit = k - assert span_id is not None - out[span_id]["execute", prefix, activity, unit] = v - return dict(out) + out = self.digests_total_since_heartbeat + self.digests_total_since_heartbeat = {} + return out diff --git a/distributed/tests/test_worker_metrics.py b/distributed/tests/test_worker_metrics.py index cfc0bbdbb90..5f2ac1b8614 100644 --- a/distributed/tests/test_worker_metrics.py +++ b/distributed/tests/test_worker_metrics.py @@ -594,3 +594,25 @@ async def test_no_spans_extension(c, s, a): if not WINDOWS: assert w_metrics[wk] > 0 assert s_metrics[sk] == w_metrics[wk] + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_new_metrics_during_heartbeat(c, s, a): + """Make sure that metrics generated during the heartbeat don't get lost""" + # Create default span + await c.submit(inc, 1) + span = s.extensions["spans"].spans_search_by_name["default",][0] + + hb_task = asyncio.create_task(a.heartbeat()) + n = 0 + while not hb_task.done(): + n += 1 + a.digest_metric(("execute", span.id, "x", "test", "test"), 1) + await asyncio.sleep(0) + await hb_task + assert n > 10 + await a.heartbeat() + + assert a.digests_total["execute", span.id, "x", "test", "test"] == n + assert s.cumulative_worker_metrics["execute", "x", "test", "test"] == n + assert span.cumulative_worker_metrics["execute", "x", "test", "test"] == n diff --git a/distributed/worker.py b/distributed/worker.py index 9e7bce3bd5d..9e2e94b2b63 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1034,14 +1034,24 @@ async def get_metrics(self) -> dict: # spilling is disabled spilled_memory, spilled_disk = 0, 0 - # Squash span_id in metrics. - # SpansWorkerExtension, if loaded, will send them out disaggregated. + # Send Fine Performance Metrics + # Make sure we do not yield the event loop between the moment we parse + # self.digests_total_since_heartbeat to send it to the scheduler and the moment + # we clear it! + spans_ext: SpansWorkerExtension | None = self.extensions.get("spans") + if spans_ext: + # Send metrics with disaggregated span_id + spans_ext.collect_digests() + + # Send metrics with squashed span_id digests: defaultdict[Hashable, float] = defaultdict(float) for k, v in self.digests_total_since_heartbeat.items(): if isinstance(k, tuple) and k[0] == "execute": k = k[:1] + k[2:] digests[k] += v + self.digests_total_since_heartbeat.clear() + out = dict( task_counts=self.state.task_counter.current_count(by_prefix=False), bandwidth={ @@ -1259,7 +1269,6 @@ async def heartbeat(self) -> None: if hasattr(extension, "heartbeat") }, ) - self.digests_total_since_heartbeat.clear() end = time() middle = (start + end) / 2