Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spans: refactor sums of mappings #7918

Merged
merged 1 commit into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion distributed/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator, MutableSet
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, MutableSet
from typing import Any, TypeVar, cast

T = TypeVar("T", bound=Hashable)
Expand Down Expand Up @@ -198,3 +198,17 @@ def clear(self) -> None:
self._data.clear()
self._heap.clear()
self._sorted = True


def sum_mappings(ds: Iterable[Mapping[K, V] | Iterable[tuple[K, V]]], /) -> dict[K, V]:
    """Sum the values of the given mappings, key by key.

    Parameters
    ----------
    ds
        Iterable (a one-shot iterator is fine) whose elements are either
        mappings or iterables of ``(key, value)`` pairs. The same key may appear
        several times, both across elements and within a single iterable of
        pairs.

    Returns
    -------
    dict
        A new plain ``dict`` mapping each key to the sum of all values seen for
        it. Keys are inserted in order of first occurrence.

    Notes
    -----
    Values are combined with binary ``+`` and not with in-place ``+=``. The
    first value seen for a key is stored as-is (no copy), so an in-place add
    would mutate mutable values, e.g. lists, that still belong to the caller's
    input mappings.
    """
    out: dict[K, V] = {}
    for d in ds:
        if isinstance(d, Mapping):
            d = d.items()
        for k, v in d:
            # Keep the try body down to the lookup, so that a KeyError raised
            # by the addition itself is not mistaken for a missing key.
            try:
                prev = out[k]
            except KeyError:
                out[k] = v
            else:
                out[k] = prev + v  # type: ignore
    return out
28 changes: 10 additions & 18 deletions distributed/spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import dask.config

from distributed.collections import sum_mappings
from distributed.metrics import time

if TYPE_CHECKING:
Expand Down Expand Up @@ -199,19 +200,15 @@ def stop(self) -> float:
return max(tg.stop for tg in self.traverse_groups())

@property
def states(self) -> defaultdict[TaskStateState, int]:
def states(self) -> dict[TaskStateState, int]:
"""The number of tasks currently in each state in this span tree;
e.g. ``{"memory": 10, "processing": 3, "released": 4, ...}``.

See also
--------
distributed.scheduler.TaskGroup.states
"""
out: defaultdict[TaskStateState, int] = defaultdict(int)
for tg in self.traverse_groups():
for state, count in tg.states.items():
out[state] += count
return out
return sum_mappings(tg.states for tg in self.traverse_groups())

@property
def done(self) -> bool:
Expand All @@ -230,19 +227,15 @@ def done(self) -> bool:
return all(tg.done for tg in self.traverse_groups())

@property
def all_durations(self) -> defaultdict[str, float]:
def all_durations(self) -> dict[str, float]:
"""Cumulative duration of all completed actions in this span tree, by action

See also
--------
duration
distributed.scheduler.TaskGroup.all_durations
"""
out: defaultdict[str, float] = defaultdict(float)
for tg in self.traverse_groups():
for action, nsec in tg.all_durations.items():
out[action] += nsec
return out
return sum_mappings(tg.all_durations for tg in self.traverse_groups())

@property
def duration(self) -> float:
Expand All @@ -266,7 +259,7 @@ def nbytes_total(self) -> int:
return sum(tg.nbytes_total for tg in self.traverse_groups())

@property
def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]:
def cumulative_worker_metrics(self) -> dict[tuple[Hashable, ...], float]:
"""Replica of Worker.digests_total and Scheduler.cumulative_worker_metrics, but
only for the metrics that can be attributed to the current span tree.
The span_id has been removed from the key.
Expand All @@ -276,11 +269,9 @@ def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]:
but more may be added in the future with a different format; please test for
``k[0] == "execute"``.
"""
out: defaultdict[tuple[Hashable, ...], float] = defaultdict(float)
for child in self.traverse_spans():
for k, v in child._cumulative_worker_metrics.items():
out[k] += v
return out
return sum_mappings(
child._cumulative_worker_metrics for child in self.traverse_spans()
)

@staticmethod
def merge(*items: Span) -> Span:
Expand Down Expand Up @@ -471,6 +462,7 @@ def heartbeat(self) -> dict[tuple[Hashable, ...], float]:
--------
SpansSchedulerExtension.heartbeat
Span.cumulative_worker_metrics
distributed.worker.Worker.get_metrics
"""
out = self.digests_total_since_heartbeat
self.digests_total_since_heartbeat = {}
Expand Down
32 changes: 31 additions & 1 deletion distributed/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import operator
import pickle
import random
from collections.abc import Mapping

import pytest

from distributed.collections import LRU, HeapSet
from distributed.collections import LRU, HeapSet, sum_mappings


def test_lru():
Expand Down Expand Up @@ -345,3 +346,32 @@ def test_heapset_sort_duplicate():
heap.add(c1)

assert list(heap.sorted()) == [c1, c2]


class ReadOnlyMapping(Mapping):
    """Bare-bones ``Mapping`` that delegates to a wrapped mapping.

    Only the three abstract methods of ``collections.abc.Mapping`` are
    implemented, so instances expose no mutating methods and are not ``dict``
    subclasses.
    """

    def __init__(self, d: Mapping):
        # Hold a reference to the wrapped mapping; no copy is made
        self.d = d

    def __len__(self):
        return len(self.d)

    def __iter__(self):
        return iter(self.d)

    def __getitem__(self, item):
        return self.d[item]


def test_sum_mappings():
    """Exercise sum_mappings with a dict, a non-dict Mapping and an iterator of
    pairs, all fed through a one-shot iterator.
    """
    plain = {"x": 1, "y": 1.2, "z": [3, 4]}
    wrapped = ReadOnlyMapping({"w": 7, "y": 3.4, "z": [5, 6]})
    # Same key twice within a single iterable of pairs
    pairs = iter([("y", 0.2), ("y", -0.5)])

    total = sum_mappings(iter([plain, wrapped, pairs]))
    assert isinstance(total, dict)
    assert total == {"x": 1, "y": 4.3, "z": [3, 4, 5, 6], "w": 7}
    assert isinstance(total["x"], int)  # Not 1.0
    # Keys come out in order of first appearance
    assert list(total) == ["x", "y", "z", "w"]

    # A generator of pairs whose keys collide after being cut to one character
    prefixed = {"x0": 1, "x1": 2, "y0": 4}
    total = sum_mappings([((k[0], v) for k, v in prefixed.items())])
    assert total == {"x": 3, "y": 4}