Skip to content

Commit

Permalink
Refactor event logging functionality into broker (#8731)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Jul 1, 2024
1 parent f997f21 commit 63e5108
Show file tree
Hide file tree
Showing 15 changed files with 327 additions and 140 deletions.
106 changes: 106 additions & 0 deletions distributed/broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from __future__ import annotations

import logging
from collections import defaultdict, deque
from collections.abc import Collection
from functools import partial
from typing import TYPE_CHECKING, Any, overload

from distributed.metrics import time

if TYPE_CHECKING:
from distributed import Scheduler

logger = logging.getLogger(__name__)


class Topic:
    """A single named event stream with bounded history and subscribers.

    Attributes
    ----------
    events:
        Buffer of the most recent events; entries beyond ``maxlen`` are
        evicted oldest-first.
    count:
        Lifetime number of events published, including evicted ones.
    subscribers:
        Names of the clients currently subscribed to this topic.
    """

    events: deque
    count: int
    subscribers: set

    def __init__(self, maxlen: int):
        self.events = deque(maxlen=maxlen)
        self.count = 0
        self.subscribers = set()

    def subscribe(self, subscriber: str) -> None:
        """Register *subscriber*; subscribing twice is a no-op."""
        self.subscribers.add(subscriber)

    def unsubscribe(self, subscriber: str) -> None:
        """Remove *subscriber*; unknown names are silently ignored."""
        self.subscribers.discard(subscriber)

    def publish(self, event: Any) -> None:
        """Append *event* to the history and bump the lifetime counter."""
        self.events.append(event)
        self.count += 1

    def truncate(self) -> None:
        """Drop the stored events; the lifetime counter is left untouched."""
        self.events.clear()


class Broker:
    """Publish/subscribe hub for scheduler events.

    Routes each published message into a per-topic :class:`Topic` buffer,
    forwards it to the clients subscribed to that topic, and hands it to
    every scheduler plugin's ``log_event`` hook.
    """

    _scheduler: Scheduler
    _topics: defaultdict[str, Topic]

    def __init__(self, maxlen: int, scheduler: Scheduler) -> None:
        self._scheduler = scheduler
        self._topics = defaultdict(partial(Topic, maxlen=maxlen))

    def subscribe(self, topic: str, subscriber: str) -> None:
        """Subscribe *subscriber* to *topic*, creating the topic if needed."""
        self._topics[topic].subscribe(subscriber)

    def unsubscribe(self, topic: str, subscriber: str) -> None:
        """Drop *subscriber* from *topic*; unknown subscribers are ignored."""
        self._topics[topic].unsubscribe(subscriber)

    def publish(self, topics: str | Collection[str], msg: Any) -> None:
        """Publish *msg* to one or more *topics*.

        The message is timestamped once, then for each topic in turn it is
        stored, pushed to that topic's subscribers, and passed to every
        scheduler plugin's ``log_event`` (plugin failures are logged, never
        raised).
        """
        stamped = (time(), msg)
        names = [topics] if isinstance(topics, str) else topics
        for name in names:
            self._topics[name].publish(stamped)
            self._send_to_subscribers(name, stamped)

            # Notify plugins per topic so each call carries the topic name;
            # copy the plugin list in case a hook mutates the registry.
            for plugin in list(self._scheduler.plugins.values()):
                try:
                    plugin.log_event(name, msg)
                except Exception:
                    logger.info("Plugin failed with exception", exc_info=True)

    def truncate(self, topic: str | None = None) -> None:
        """Clear stored events for *topic*, or for every topic when ``None``."""
        if topic is not None:
            if topic in self._topics:
                self._topics[topic].truncate()
            return
        for existing in self._topics.values():
            existing.truncate()

    def _send_to_subscribers(self, topic: str, event: Any) -> None:
        # Fan the event out to every client subscribed to this topic.
        payload = {
            "op": "event",
            "topic": topic,
            "event": event,
        }
        client_msgs = {
            client: [payload] for client in self._topics[topic].subscribers
        }
        self._scheduler.send_all(client_msgs, worker_msgs={})

    @overload
    def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]:
        ...

    @overload
    def get_events(
        self, topic: None = None
    ) -> dict[str, tuple[tuple[float, Any], ...]]:
        ...

    def get_events(
        self, topic: str | None = None
    ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]:
        """Return stored events for *topic*, or a mapping of every non-empty topic."""
        if topic is None:
            return {
                name: tuple(existing.events)
                for name, existing in self._topics.items()
                if existing.events
            }
        return tuple(self._topics[topic].events)
19 changes: 10 additions & 9 deletions distributed/dashboard/components/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1971,12 +1971,12 @@ def convert(self, msgs):
@without_property_validation
@log_errors
def update(self):
log = self.scheduler.get_events(topic="stealing")
current = len(self.scheduler.events["stealing"])
n = current - self.last

log = [log[-i][1][1] for i in range(1, n + 1) if log[-i][1][0] == "request"]
self.last = current
topic = self.scheduler._broker._topics["stealing"]
log = log = topic.events
n = min(topic.count - self.last, len(log))
if log:
log = [log[-i][1][1] for i in range(1, n + 1) if log[-i][1][0] == "request"]
self.last = topic.count

if log:
new = pipe(
Expand Down Expand Up @@ -2041,11 +2041,12 @@ def __init__(self, scheduler, name, height=150, **kwargs):
@without_property_validation
@log_errors
def update(self):
log = self.scheduler.events[self.name]
n = self.scheduler.event_counts[self.name] - self.last
topic = self.scheduler._broker._topics[self.name]
log = topic.events
n = min(topic.count - self.last, len(log))
if log:
log = [log[-i] for i in range(1, n + 1)]
self.last = self.scheduler.event_counts[self.name]
self.last = topic.count

if log:
actions = []
Expand Down
43 changes: 1 addition & 42 deletions distributed/diagnostics/tests/test_scheduler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from distributed import Nanny, Scheduler, SchedulerPlugin, Worker, get_worker
from distributed import Nanny, Scheduler, SchedulerPlugin, Worker
from distributed.protocol.pickle import dumps
from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc

Expand Down Expand Up @@ -435,47 +435,6 @@ class Plugin(SchedulerPlugin):
await c.unregister_scheduler_plugin(name="plugin")


@gen_cluster(client=True)
async def test_log_event_plugin(c, s, a, b):
    """Events logged from a worker reach the plugin's ``log_event`` hook."""

    class RecordingPlugin(SchedulerPlugin):
        async def start(self, scheduler: Scheduler) -> None:
            self.scheduler = scheduler
            self.scheduler._recorded_events = []  # type: ignore

        def log_event(self, topic, msg):
            self.scheduler._recorded_events.append((topic, msg))

    await c.register_plugin(RecordingPlugin())

    def emit():
        get_worker().log_event("foo", 123)

    await c.submit(emit)

    assert ("foo", 123) in s._recorded_events


@gen_cluster(client=True)
async def test_log_event_plugin_multiple_topics(c, s, a, b):
    """A single event logged under several topics reaches the plugin once per topic."""

    class RecordingPlugin(SchedulerPlugin):
        async def start(self, scheduler: Scheduler) -> None:
            self.scheduler = scheduler
            self.scheduler._recorded_events = []  # type: ignore

        def log_event(self, topic, msg):
            self.scheduler._recorded_events.append((topic, msg))

    await c.register_plugin(RecordingPlugin())

    def emit():
        get_worker().log_event(["foo", "bar"], 123)

    await c.submit(emit)

    assert ("foo", 123) in s._recorded_events
    assert ("bar", 123) in s._recorded_events


@gen_cluster(client=True)
async def test_register_plugin_on_scheduler(c, s, a, b):
class MyPlugin(SchedulerPlugin):
Expand Down
4 changes: 2 additions & 2 deletions distributed/http/scheduler/tests/test_stealing_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def fetch_metrics_by_cost_multipliers():
count = sum(active_metrics.values())
assert count > 0
expected_count = sum(
len(event[1]) for _, event in s.events["stealing"] if event[0] == "request"
len(event[1]) for _, event in s.get_events("stealing") if event[0] == "request"
)
assert count == expected_count

Expand Down Expand Up @@ -87,7 +87,7 @@ async def fetch_metrics_by_cost_multipliers():
assert count > 0
expected_cost = sum(
request[3]
for _, event in s.events["stealing"]
for _, event in s.get_events("stealing")
for request in event[1]
if event[0] == "request"
)
Expand Down
59 changes: 22 additions & 37 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
from distributed.broker import Broker
from distributed.client import SourceCode
from distributed.collections import HeapSet
from distributed.comm import (
Expand Down Expand Up @@ -3804,9 +3805,7 @@ async def post(self):
]

maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.events = defaultdict(partial(deque, maxlen=maxlen))
self.event_counts = defaultdict(int)
self.event_subscriber = defaultdict(set)
self._broker = Broker(maxlen, self)
self.worker_plugins = {}
self.nanny_plugins = {}
self._starting_nannies = set()
Expand Down Expand Up @@ -4002,7 +4001,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
"workers": self.workers,
"clients": self.clients,
"memory": self.memory,
"events": self.events,
"events": self._broker._topics,
"extensions": self.extensions,
}
extra = {k: v for k, v in extra.items() if k not in exclude}
Expand Down Expand Up @@ -5406,8 +5405,8 @@ async def remove_worker(

async def remove_worker_from_events() -> None:
# If the worker isn't registered anymore after the delay, remove from events
if address not in self.workers and address in self.events:
del self.events[address]
if address not in self.workers:
self._broker.truncate(address)

cleanup_delay = parse_timedelta(
dask.config.get("distributed.scheduler.events-cleanup-delay")
Expand Down Expand Up @@ -5820,8 +5819,8 @@ def remove_client(self, client: str, stimulus_id: str | None = None) -> None:

async def remove_client_from_events() -> None:
# If the client isn't registered anymore after the delay, remove from events
if client not in self.clients and client in self.events:
del self.events[client]
if client not in self.clients:
self._broker.truncate(client)

cleanup_delay = parse_timedelta(
dask.config.get("distributed.scheduler.events-cleanup-delay")
Expand Down Expand Up @@ -8423,40 +8422,26 @@ def log_event(self, topic: str | Collection[str], msg: Any) -> None:
--------
Client.log_event
"""
event = (time(), msg)
if isinstance(topic, str):
topic = [topic]
for t in topic:
self.events[t].append(event)
self.event_counts[t] += 1
self._report_event(t, event)
self._broker.publish(topic, msg)

for plugin in list(self.plugins.values()):
try:
plugin.log_event(t, msg)
except Exception:
logger.info("Plugin failed with exception", exc_info=True)
def subscribe_topic(self, topic: str, client: str) -> None:
self._broker.subscribe(topic, client)

def _report_event(self, name, event):
msg = {
"op": "event",
"topic": name,
"event": event,
}
client_msgs = {client: [msg] for client in self.event_subscriber[name]}
self.send_all(client_msgs, worker_msgs={})
def unsubscribe_topic(self, topic: str, client: str) -> None:
self._broker.unsubscribe(topic, client)

def subscribe_topic(self, topic, client):
self.event_subscriber[topic].add(client)
@overload
def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]:
...

def unsubscribe_topic(self, topic, client):
self.event_subscriber[topic].discard(client)
@overload
def get_events(self) -> dict[str, tuple[tuple[float, Any], ...]]:
...

def get_events(self, topic=None):
if topic is not None:
return tuple(self.events[topic])
else:
return valmap(tuple, self.events)
def get_events(
self, topic: str | None = None
) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]:
return self._broker.get_events(topic)

async def get_worker_monitor_info(self, recent=False, starts=None):
if starts is None:
Expand Down
8 changes: 4 additions & 4 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b):

with pytest.raises(KilledWorker):
await out
assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1
assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1

await c.close()
await check_worker_cleanup(a)
Expand All @@ -460,7 +460,7 @@ async def test_restarting_does_not_log_p2p_failed(c, s, a, b):
await b.close()

await out
assert not s.events["p2p"]
assert not s.get_events("p2p")
await c.close()
await check_worker_cleanup(a)
await check_worker_cleanup(b, closed=True)
Expand Down Expand Up @@ -831,7 +831,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b):

with pytest.raises(KilledWorker):
await out
assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1
assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1

alive_shuffle.block_inputs_done.set()

Expand Down Expand Up @@ -994,7 +994,7 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b):

with pytest.raises(KilledWorker):
await out
assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1
assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1

await c.close()
await check_worker_cleanup(a)
Expand Down
4 changes: 1 addition & 3 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import logging
from collections import defaultdict, deque
from collections import defaultdict
from collections.abc import Container
from functools import partial
from math import log2
Expand Down Expand Up @@ -106,8 +106,6 @@ def __init__(self, scheduler: Scheduler):
)
# `callback_time` is in milliseconds
self.scheduler.add_plugin(self)
maxlen = dask.config.get("distributed.admin.low-level-log-length")
self.scheduler.events["stealing"] = deque(maxlen=maxlen)
self.count = 0
self.in_flight = {}
self.in_flight_occupancy = defaultdict(int)
Expand Down
4 changes: 3 additions & 1 deletion distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6461,7 +6461,9 @@ def test_direct_to_workers(s, loop):
with Client(s["address"], loop=loop, direct_to_workers=True) as client:
future = client.scatter(1)
future.result()
resp = client.run_on_scheduler(lambda dask_scheduler: dask_scheduler.events)
resp = client.run_on_scheduler(
lambda dask_scheduler: dask_scheduler.get_events()
)
assert "gather" not in str(resp)


Expand Down
Loading

0 comments on commit 63e5108

Please sign in to comment.