diff --git a/distributed/collections.py b/distributed/collections.py
index ebe26053858..34a0ff097f9 100644
--- a/distributed/collections.py
+++ b/distributed/collections.py
@@ -4,7 +4,7 @@
 import itertools
 import weakref
 from collections import OrderedDict, UserDict
-from collections.abc import Callable, Hashable, Iterator, MutableSet
+from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, MutableSet
 from typing import Any, TypeVar, cast
 
 T = TypeVar("T", bound=Hashable)
@@ -198,3 +198,28 @@ def clear(self) -> None:
         self._data.clear()
         self._heap.clear()
         self._sorted = True
+
+
+def sum_mappings(ds: Iterable[Mapping[K, V] | Iterable[tuple[K, V]]], /) -> dict[K, V]:
+    """Sum the values of the given mappings, key by key.
+
+    Each element of ``ds`` is either a mapping or an iterable of ``(key, value)``
+    pairs; a key may appear any number of times. Values are combined with ``+``
+    in the order they are encountered, and keys keep first-seen order.
+
+    The inputs are never modified: values are combined with ``a + b`` rather than
+    ``a += b``, because the first value stored for a key is the caller's own
+    object and an in-place add would mutate it (e.g. a ``list`` value).
+    """
+    out: dict[K, V] = {}
+    for d in ds:
+        if isinstance(d, Mapping):
+            d = d.items()
+        for k, v in d:
+            try:
+                prev = out[k]
+            except KeyError:
+                out[k] = v
+            else:
+                out[k] = prev + v  # type: ignore
+    return out
diff --git a/distributed/spans.py b/distributed/spans.py
index 49a8e7f5b2a..d85d52b1533 100644
--- a/distributed/spans.py
+++ b/distributed/spans.py
@@ -9,6 +9,7 @@
 
 import dask.config
 
+from distributed.collections import sum_mappings
 from distributed.metrics import time
 
 if TYPE_CHECKING:
@@ -199,7 +200,7 @@ def stop(self) -> float:
         return max(tg.stop for tg in self.traverse_groups())
 
     @property
-    def states(self) -> defaultdict[TaskStateState, int]:
+    def states(self) -> dict[TaskStateState, int]:
         """The number of tasks currently in each state in this span tree;
         e.g. ``{"memory": 10, "processing": 3, "released": 4, ...}``.
 
@@ -207,11 +208,7 @@ def states(self) -> defaultdict[TaskStateState, int]:
         --------
         distributed.scheduler.TaskGroup.states
         """
-        out: defaultdict[TaskStateState, int] = defaultdict(int)
-        for tg in self.traverse_groups():
-            for state, count in tg.states.items():
-                out[state] += count
-        return out
+        return sum_mappings(tg.states for tg in self.traverse_groups())
 
     @property
     def done(self) -> bool:
@@ -230,7 +227,7 @@ def done(self) -> bool:
         return all(tg.done for tg in self.traverse_groups())
 
     @property
-    def all_durations(self) -> defaultdict[str, float]:
+    def all_durations(self) -> dict[str, float]:
         """Cumulative duration of all completed actions in this span tree, by action
 
         See also
@@ -238,11 +235,7 @@ def all_durations(self) -> defaultdict[str, float]:
         duration
         distributed.scheduler.TaskGroup.all_durations
         """
-        out: defaultdict[str, float] = defaultdict(float)
-        for tg in self.traverse_groups():
-            for action, nsec in tg.all_durations.items():
-                out[action] += nsec
-        return out
+        return sum_mappings(tg.all_durations for tg in self.traverse_groups())
 
     @property
     def duration(self) -> float:
@@ -266,7 +259,7 @@ def nbytes_total(self) -> int:
         return sum(tg.nbytes_total for tg in self.traverse_groups())
 
     @property
-    def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]:
+    def cumulative_worker_metrics(self) -> dict[tuple[Hashable, ...], float]:
         """Replica of Worker.digests_total and Scheduler.cumulative_worker_metrics, but
         only for the metrics that can be attributed to the current span tree.
         The span_id has been removed from the key.
@@ -276,11 +269,9 @@ def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]:
         but more may be added in the future with a different format; please test for
         ``k[0] == "execute"``.
         """
-        out: defaultdict[tuple[Hashable, ...], float] = defaultdict(float)
-        for child in self.traverse_spans():
-            for k, v in child._cumulative_worker_metrics.items():
-                out[k] += v
-        return out
+        return sum_mappings(
+            child._cumulative_worker_metrics for child in self.traverse_spans()
+        )
 
     @staticmethod
     def merge(*items: Span) -> Span:
@@ -471,6 +462,7 @@ def heartbeat(self) -> dict[tuple[Hashable, ...], float]:
         --------
         SpansSchedulerExtension.heartbeat
         Span.cumulative_worker_metrics
+        distributed.worker.Worker.get_metrics
         """
         out = self.digests_total_since_heartbeat
         self.digests_total_since_heartbeat = {}
diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py
index 6db1811072d..ce3c3778734 100644
--- a/distributed/tests/test_collections.py
+++ b/distributed/tests/test_collections.py
@@ -4,10 +4,11 @@
 import operator
 import pickle
 import random
+from collections.abc import Mapping
 
 import pytest
 
-from distributed.collections import LRU, HeapSet
+from distributed.collections import LRU, HeapSet, sum_mappings
 
 
 def test_lru():
@@ -345,3 +346,34 @@ def test_heapset_sort_duplicate():
     heap.add(c1)
 
     assert list(heap.sorted()) == [c1, c2]
+
+
+class ReadOnlyMapping(Mapping):
+    def __init__(self, d: Mapping):
+        self.d = d
+
+    def __getitem__(self, item):
+        return self.d[item]
+
+    def __iter__(self):
+        return iter(self.d)
+
+    def __len__(self):
+        return len(self.d)
+
+
+def test_sum_mappings():
+    a = {"x": 1, "y": 1.2, "z": [3, 4]}
+    b = ReadOnlyMapping({"w": 7, "y": 3.4, "z": [5, 6]})
+    c = iter([("y", 0.2), ("y", -0.5)])
+    actual = sum_mappings(iter([a, b, c]))
+    assert isinstance(actual, dict)
+    assert actual == {"x": 1, "y": 4.3, "z": [3, 4, 5, 6], "w": 7}
+    assert isinstance(actual["x"], int)  # Not 1.0
+    assert list(actual) == ["x", "y", "z", "w"]
+    # Summing must not mutate the caller's values (e.g. the list under "z")
+    assert a == {"x": 1, "y": 1.2, "z": [3, 4]}
+
+    d = {"x0": 1, "x1": 2, "y0": 4}
+    actual = sum_mappings([((k[0], v) for k, v in d.items())])
+    assert actual == {"x": 3, "y": 4}