diff --git a/distributed/broker.py b/distributed/broker.py
new file mode 100644
index 00000000000..298225eb85f
--- /dev/null
+++ b/distributed/broker.py
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+import logging
+from collections import defaultdict, deque
+from collections.abc import Collection
+from functools import partial
+from typing import TYPE_CHECKING, Any, overload
+
+from distributed.metrics import time
+
+if TYPE_CHECKING:
+    from distributed import Scheduler
+
+logger = logging.getLogger(__name__)
+
+
+class Topic:
+    """Bounded event log and subscriber set for a single topic."""
+
+    # Most recent events, oldest first; at most ``maxlen`` are retained
+    events: deque[tuple[float, Any]]
+    # Number of events ever published; not reduced by eviction or truncation
+    count: int
+    # Identifiers of the subscribers to notify about new events
+    subscribers: set[str]
+
+    def __init__(self, maxlen: int):
+        self.events = deque(maxlen=maxlen)
+        self.count = 0
+        self.subscribers = set()
+
+    def subscribe(self, subscriber: str) -> None:
+        """Add ``subscriber``; subscribing twice has no further effect."""
+        self.subscribers.add(subscriber)
+
+    def unsubscribe(self, subscriber: str) -> None:
+        """Remove ``subscriber``; unknown subscribers are ignored."""
+        self.subscribers.discard(subscriber)
+
+    def publish(self, event: Any) -> None:
+        """Append ``event`` to the log and increment the running count."""
+        self.events.append(event)
+        self.count += 1
+
+    def truncate(self) -> None:
+        """Drop all retained events; ``count`` and subscribers are kept."""
+        self.events.clear()
+
+
+class Broker:
+    """Publish/subscribe hub behind the scheduler's event log.
+
+    Events are stored per topic in a bounded :class:`Topic` log, sent to the
+    subscribers of that topic, and passed to every scheduler plugin.
+    """
+
+    _scheduler: Scheduler
+    _topics: defaultdict[str, Topic]
+
+    def __init__(self, maxlen: int, scheduler: Scheduler) -> None:
+        self._scheduler = scheduler
+        # Topics are created on first access, each retaining ``maxlen`` events
+        self._topics = defaultdict(partial(Topic, maxlen=maxlen))
+
+    def subscribe(self, topic: str, subscriber: str) -> None:
+        """Register ``subscriber`` for ``topic``, creating the topic if needed."""
+        self._topics[topic].subscribe(subscriber)
+
+    def unsubscribe(self, topic: str, subscriber: str) -> None:
+        """Remove ``subscriber`` from ``topic``; unknown topics are ignored."""
+        # ``.get`` bypasses the default factory, so unsubscribing from a topic
+        # that was never used does not allocate an empty Topic
+        existing = self._topics.get(topic)
+        if existing is not None:
+            existing.unsubscribe(subscriber)
+
+    def publish(self, topics: str | Collection[str], msg: Any) -> None:
+        """Record ``msg`` under one or several topics.
+
+        The same ``(timestamp, msg)`` event is appended to every topic, sent
+        to the subscribers of each topic and passed to the ``log_event`` hook
+        of every scheduler plugin. A failing plugin is logged and skipped.
+        """
+        event = (time(), msg)
+        if isinstance(topics, str):
+            topics = [topics]
+        for name in topics:
+            topic = self._topics[name]
+            topic.publish(event)
+            self._send_to_subscribers(name, event)
+
+            # Iterate over a snapshot of the registered plugins
+            for plugin in list(self._scheduler.plugins.values()):
+                try:
+                    plugin.log_event(name, msg)
+                except Exception:
+                    logger.info("Plugin failed with exception", exc_info=True)
+
+    def truncate(self, topic: str | None = None) -> None:
+        """Discard the retained events of ``topic``, or of all topics if None.
+
+        A single truncated topic without subscribers is removed entirely;
+        otherwise the entries of departed workers and clients would pile up
+        for the lifetime of the scheduler.
+        """
+        if topic is None:
+            for _topic in self._topics.values():
+                _topic.truncate()
+            return
+        existing = self._topics.get(topic)
+        if existing is None:
+            return
+        if existing.subscribers:
+            existing.truncate()
+        else:
+            del self._topics[topic]
+
+    def _send_to_subscribers(self, topic: str, event: Any) -> None:
+        """Send ``event`` to every subscriber of ``topic``."""
+        msg = {
+            "op": "event",
+            "topic": topic,
+            "event": event,
+        }
+        client_msgs = {client: [msg] for client in self._topics[topic].subscribers}
+        self._scheduler.send_all(client_msgs, worker_msgs={})
+
+    @overload
+    def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]:
+        ...
+
+    @overload
+    def get_events(
+        self, topic: None = None
+    ) -> dict[str, tuple[tuple[float, Any], ...]]:
+        ...
+
+    def get_events(
+        self, topic: str | None = None
+    ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]:
+        """Return a snapshot of the retained events.
+
+        With ``topic`` given, return that topic's events as a tuple, which is
+        empty for unknown topics. Without it, return a mapping from every
+        topic that has events to a tuple of its events.
+        """
+        if topic is not None:
+            # ``.get`` keeps this read-only: no Topic is created for names
+            # that were never published or subscribed to
+            existing = self._topics.get(topic)
+            return tuple(existing.events) if existing is not None else ()
+        return {
+            name: tuple(_topic.events)
+            for name, _topic in self._topics.items()
+            if _topic.events
+        }
diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py
index 7f5b391ca00..982d234c9ab 100644
--- a/distributed/dashboard/components/scheduler.py
+++ b/distributed/dashboard/components/scheduler.py
@@ -1971,12 +1971,12 @@ def convert(self, msgs):
     @without_property_validation
     @log_errors
     def update(self):
-        log = self.scheduler.get_events(topic="stealing")
-        current = len(self.scheduler.events["stealing"])
-        n = current - self.last
-
-        log = [log[-i][1][1] for i in range(1, n + 1) if log[-i][1][0] == "request"]
-        self.last = current
+        topic = self.scheduler._broker._topics["stealing"]
+        log = topic.events
+        n = min(topic.count - self.last, len(log))
+        if log:
+            log = [log[-i][1][1] for i in range(1, n + 1) if log[-i][1][0] == "request"]
+        self.last = topic.count
 
         if log:
             new = pipe(
@@ -2041,11 +2041,12 @@ def __init__(self, scheduler, name, height=150, **kwargs):
     @without_property_validation
     @log_errors
     def update(self):
-        log = self.scheduler.events[self.name]
-        n = self.scheduler.event_counts[self.name] - self.last
+        topic = self.scheduler._broker._topics[self.name]
+        log = topic.events
+        n = min(topic.count - self.last, len(log))
         if log:
             log = [log[-i] for i in range(1, n + 1)]
-        self.last = self.scheduler.event_counts[self.name]
+        self.last = topic.count
 
         if log:
             actions = []
diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py
index ee8c7f501c0..c49bd965104 100644
--- a/distributed/diagnostics/tests/test_scheduler_plugin.py
+++ b/distributed/diagnostics/tests/test_scheduler_plugin.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from distributed import Nanny, Scheduler, SchedulerPlugin, Worker, get_worker
+from distributed import Nanny, Scheduler, SchedulerPlugin, Worker
 from distributed.protocol.pickle import dumps
 from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc
 
@@ -435,47 +435,6 @@ class Plugin(SchedulerPlugin):
     await c.unregister_scheduler_plugin(name="plugin")
 
 
-@gen_cluster(client=True)
-async def test_log_event_plugin(c, s, a, b):
-    class EventPlugin(SchedulerPlugin):
-        async def start(self, scheduler: Scheduler) -> None:
-            self.scheduler = scheduler
-            self.scheduler._recorded_events = list()  # type: ignore
-
-        def log_event(self, topic, msg):
-            self.scheduler._recorded_events.append((topic, msg))
-
-    await c.register_plugin(EventPlugin())
-
-    def f():
-        get_worker().log_event("foo", 123)
-
-    await c.submit(f)
-
-    assert ("foo", 123) in s._recorded_events
-
-
-@gen_cluster(client=True)
-async def test_log_event_plugin_multiple_topics(c, s, a, b):
-    class EventPlugin(SchedulerPlugin):
-        async def start(self, scheduler: Scheduler) -> None:
-            self.scheduler = scheduler
-            self.scheduler._recorded_events = list()  # type: ignore
-
-        def log_event(self, topic, msg):
-            self.scheduler._recorded_events.append((topic, msg))
-
-    await c.register_plugin(EventPlugin())
-
-    def f():
-        get_worker().log_event(["foo", "bar"], 123)
-
-    await c.submit(f)
-
-    assert ("foo", 123) in s._recorded_events
-    assert ("bar", 123) in s._recorded_events
-
-
 @gen_cluster(client=True)
 async def test_register_plugin_on_scheduler(c, s, a, b):
     class MyPlugin(SchedulerPlugin):
diff --git a/distributed/http/scheduler/tests/test_stealing_http.py b/distributed/http/scheduler/tests/test_stealing_http.py
index 12720763748..efd5589612c 100644
--- a/distributed/http/scheduler/tests/test_stealing_http.py
+++ b/distributed/http/scheduler/tests/test_stealing_http.py
@@ -53,7 +53,7 @@ async def fetch_metrics_by_cost_multipliers():
     count = sum(active_metrics.values())
     assert count > 0
     expected_count = sum(
-        len(event[1]) for _, event in s.events["stealing"] if event[0] == "request"
+        len(event[1]) for _, event in s.get_events("stealing") if event[0] == "request"
     )
     assert count == expected_count
 
@@ -87,7 +87,7 @@ async def fetch_metrics_by_cost_multipliers():
     assert count > 0
     expected_cost = sum(
         request[3]
-        for _, event in s.events["stealing"]
+        for _, event in s.get_events("stealing")
         for request in event[1]
         if event[0] == "request"
     )
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 69ca80826e2..0273d333da3 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -75,6 +75,7 @@
 from distributed._stories import scheduler_story
 from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
 from distributed.batched import BatchedSend
+from distributed.broker import Broker
 from distributed.client import SourceCode
 from distributed.collections import HeapSet
 from distributed.comm import (
@@ -3804,9 +3805,7 @@ async def post(self):
         ]
 
         maxlen = dask.config.get("distributed.admin.low-level-log-length")
-        self.events = defaultdict(partial(deque, maxlen=maxlen))
-        self.event_counts = defaultdict(int)
-        self.event_subscriber = defaultdict(set)
+        self._broker = Broker(maxlen, self)
         self.worker_plugins = {}
         self.nanny_plugins = {}
         self._starting_nannies = set()
@@ -4002,7 +4001,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
             "workers": self.workers,
             "clients": self.clients,
             "memory": self.memory,
-            "events": self.events,
+            "events": self._broker._topics,
             "extensions": self.extensions,
         }
         extra = {k: v for k, v in extra.items() if k not in exclude}
@@ -5406,8 +5405,8 @@ async def remove_worker(
 
         async def remove_worker_from_events() -> None:
             # If the worker isn't registered anymore after the delay, remove from events
-            if address not in self.workers and address in self.events:
-                del self.events[address]
+            if address not in self.workers:
+                self._broker.truncate(address)
 
         cleanup_delay = parse_timedelta(
             dask.config.get("distributed.scheduler.events-cleanup-delay")
@@ -5820,8 +5819,8 @@ def remove_client(self, client: str, stimulus_id: str | None = None) -> None:
 
         async def remove_client_from_events() -> None:
             # If the client isn't registered anymore after the delay, remove from events
-            if client not in self.clients and client in self.events:
-                del self.events[client]
+            if client not in self.clients:
+                self._broker.truncate(client)
 
         cleanup_delay = parse_timedelta(
             dask.config.get("distributed.scheduler.events-cleanup-delay")
@@ -8423,40 +8422,26 @@ def log_event(self, topic: str | Collection[str], msg: Any) -> None:
         --------
         Client.log_event
         """
-        event = (time(), msg)
-        if isinstance(topic, str):
-            topic = [topic]
-        for t in topic:
-            self.events[t].append(event)
-            self.event_counts[t] += 1
-            self._report_event(t, event)
+        self._broker.publish(topic, msg)
 
-            for plugin in list(self.plugins.values()):
-                try:
-                    plugin.log_event(t, msg)
-                except Exception:
-                    logger.info("Plugin failed with exception", exc_info=True)
+    def subscribe_topic(self, topic: str, client: str) -> None:
+        self._broker.subscribe(topic, client)
 
-    def _report_event(self, name, event):
-        msg = {
-            "op": "event",
-            "topic": name,
-            "event": event,
-        }
-        client_msgs = {client: [msg] for client in self.event_subscriber[name]}
-        self.send_all(client_msgs, worker_msgs={})
+    def unsubscribe_topic(self, topic: str, client: str) -> None:
+        self._broker.unsubscribe(topic, client)
 
-    def subscribe_topic(self, topic, client):
-        self.event_subscriber[topic].add(client)
+    @overload
+    def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]:
+        ...
 
-    def unsubscribe_topic(self, topic, client):
-        self.event_subscriber[topic].discard(client)
+    @overload
+    def get_events(self) -> dict[str, tuple[tuple[float, Any], ...]]:
+        ...
 
-    def get_events(self, topic=None):
-        if topic is not None:
-            return tuple(self.events[topic])
-        else:
-            return valmap(tuple, self.events)
+    def get_events(
+        self, topic: str | None = None
+    ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]:
+        return self._broker.get_events(topic)
 
     async def get_worker_monitor_info(self, recent=False, starts=None):
         if starts is None:
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 24e199d1b64..310c927fd25 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -433,7 +433,7 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b):
     with pytest.raises(KilledWorker):
         await out
 
-    assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1
+    assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1
     await c.close()
     await check_worker_cleanup(a)
 
@@ -460,7 +460,7 @@ async def test_restarting_does_not_log_p2p_failed(c, s, a, b):
     await b.close()
 
     await out
-    assert not s.events["p2p"]
+    assert not s.get_events("p2p")
     await c.close()
     await check_worker_cleanup(a)
     await check_worker_cleanup(b, closed=True)
@@ -831,7 +831,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b):
     with pytest.raises(KilledWorker):
         await out
 
-    assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1
+    assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1
 
     alive_shuffle.block_inputs_done.set()
 
@@ -994,7 +994,7 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b):
     with pytest.raises(KilledWorker):
         await out
 
-    assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1
+    assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1
     await c.close()
     await check_worker_cleanup(a)
 
diff --git a/distributed/stealing.py b/distributed/stealing.py
index 952aa55b1e1..1d72e58a22a 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -2,7 +2,7 @@
 
 import asyncio
 import logging
-from collections import defaultdict, deque
+from collections import defaultdict
 from collections.abc import Container
 from functools import partial
 from math import log2
@@ -106,8 +106,6 @@ def __init__(self, scheduler: Scheduler):
         )
         # `callback_time` is in milliseconds
         self.scheduler.add_plugin(self)
-        maxlen = dask.config.get("distributed.admin.low-level-log-length")
-        self.scheduler.events["stealing"] = deque(maxlen=maxlen)
         self.count = 0
         self.in_flight = {}
         self.in_flight_occupancy = defaultdict(int)
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index f589105e672..f35de46bf94 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -6461,7 +6461,9 @@ def test_direct_to_workers(s, loop):
     with Client(s["address"], loop=loop, direct_to_workers=True) as client:
         future = client.scatter(1)
         future.result()
-        resp = client.run_on_scheduler(lambda dask_scheduler: dask_scheduler.events)
+        resp = client.run_on_scheduler(
+            lambda dask_scheduler: dask_scheduler.get_events()
+        )
     assert "gather" not in str(resp)
 
 
diff --git a/distributed/tests/test_event_logging.py b/distributed/tests/test_event_logging.py
index 2c74af3ca69..43effcb12f5 100644
--- a/distributed/tests/test_event_logging.py
+++ b/distributed/tests/test_event_logging.py
@@ -1,17 +1,52 @@
 from __future__ import annotations
 
 import asyncio
+from functools import partial
 from unittest import mock
 
 import pytest
 
-from distributed import Client, Nanny, get_worker
+from distributed import Client, Nanny, Scheduler, get_worker
 from distributed.core import error_message
+from distributed.diagnostics import SchedulerPlugin
+from distributed.metrics import time
 from distributed.utils_test import captured_logger, gen_cluster
 
 
+@gen_cluster(nthreads=[])
+async def test_log_event(s):
+    before = time()
+    s.log_event("foo", {"action": "test", "value": 1})
+    after = time()
+    assert len(s.get_events("foo")) == 1
+    timestamp, event = s.get_events("foo")[0]
+    assert before <= timestamp <= after
+    assert event == {"action": "test", "value": 1}
+
+
+@gen_cluster(nthreads=[])
+async def test_log_events(s):
+    s.log_event("foo", {"action": "test", "value": 1})
+    s.log_event(["foo", "bar"], {"action": "test", "value": 2})
+
+    actual = [event for _, event in s.get_events("foo")]
+    assert actual == [{"action": "test", "value": 1}, {"action": "test", "value": 2}]
+
+    actual = [event for _, event in s.get_events("bar")]
+    assert actual == [{"action": "test", "value": 2}]
+
+    actual = {
+        topic: [event for _, event in events]
+        for topic, events in s.get_events().items()
+    }
+    assert actual == {
+        "foo": [{"action": "test", "value": 1}, {"action": "test", "value": 2}],
+        "bar": [{"action": "test", "value": 2}],
+    }
+
+
 @gen_cluster(client=True, nthreads=[("", 1)])
-async def test_log_event(c, s, a):
+async def test_log_event_e2e(c, s, a):
     # Log an event from inside a task
     def foo():
         get_worker().log_event("topic1", {"foo": "bar"})
@@ -54,7 +89,7 @@ def handler(event):
         c.subscribe_topic("test-topic", get_event_handler(1))
         c2.subscribe_topic("test-topic", get_event_handler(2))
 
-        while len(s.event_subscriber["test-topic"]) != 2:
+        while len(s._broker._topics["test-topic"].subscribers) != 2:
             await asyncio.sleep(0.01)
 
         with captured_logger("distributed.client") as logger:
@@ -77,7 +112,7 @@ def user_event_handler(event):
 
     c.subscribe_topic("test-topic", user_event_handler)
 
-    while not s.event_subscriber["test-topic"]:
+    while not s._broker._topics["test-topic"].subscribers:
         await asyncio.sleep(0.01)
 
     a.log_event("test-topic", {"important": "event"})
@@ -91,12 +126,12 @@ def user_event_handler(event):
 
     c.unsubscribe_topic("test-topic")
 
-    while s.event_subscriber["test-topic"]:
+    while s._broker._topics["test-topic"].subscribers:
         await asyncio.sleep(0.01)
 
     a.log_event("test-topic", {"forget": "me"})
 
-    while len(s.events["test-topic"]) == 1:
+    while len(s.get_events("test-topic")) == 1:
         await asyncio.sleep(0.01)
 
     assert len(log) == 1
@@ -107,7 +142,7 @@ async def async_user_event_handler(event):
 
     c.subscribe_topic("test-topic", async_user_event_handler)
 
-    while not s.event_subscriber["test-topic"]:
+    while not s._broker._topics["test-topic"].subscribers:
         await asyncio.sleep(0.01)
 
     a.log_event("test-topic", {"async": "event"})
@@ -139,7 +174,7 @@ async def user_event_handler(event):
         await asyncio.sleep(0.5)
 
     c.subscribe_topic("test-topic", user_event_handler)
-    while not s.event_subscriber["test-topic"]:
+    while not s._broker._topics["test-topic"].subscribers:
         await asyncio.sleep(0.01)
 
     a.log_event("test-topic", {})
@@ -148,6 +183,62 @@ async def user_event_handler(event):
     assert exc_info is not None
 
 
+@gen_cluster(nthreads=[])
+async def test_topic_subscribe_unsubscribe(s):
+    async with Client(s.address, asynchronous=True) as c1, Client(
+        s.address, asynchronous=True
+    ) as c2:
+
+        def event_handler(recorded_events, event):
+            _, msg = event
+            recorded_events.append(msg)
+
+        c1_events = []
+        c1.subscribe_topic("foo", partial(event_handler, c1_events))
+        while not s._broker._topics["foo"].subscribers:
+            await asyncio.sleep(0.01)
+        s.log_event("foo", {"value": 1})
+
+        c2_events = []
+        c2.subscribe_topic("foo", partial(event_handler, c2_events))
+        c2.subscribe_topic("bar", partial(event_handler, c2_events))
+
+        while (
+            not s._broker._topics["bar"].subscribers
+            or len(s._broker._topics["foo"].subscribers) < 2
+        ):
+            await asyncio.sleep(0.01)
+
+        s.log_event("foo", {"value": 2})
+        s.log_event("bar", {"value": 3})
+
+        c2.unsubscribe_topic("foo")
+
+        while len(s._broker._topics["foo"].subscribers) > 1:
+            await asyncio.sleep(0.01)
+
+        s.log_event("foo", {"value": 4})
+        s.log_event("bar", {"value": 5})
+
+        c1.unsubscribe_topic("foo")
+
+        while s._broker._topics["foo"].subscribers:
+            await asyncio.sleep(0.01)
+
+        s.log_event("foo", {"value": 6})
+        s.log_event("bar", {"value": 7})
+
+        c2.unsubscribe_topic("bar")
+
+        while s._broker._topics["bar"].subscribers:
+            await asyncio.sleep(0.01)
+
+        s.log_event("bar", {"value": 8})
+
+    assert c1_events == [{"value": 1}, {"value": 2}, {"value": 4}]
+    assert c2_events == [{"value": 2}, {"value": 3}, {"value": 5}, {"value": 7}]
+
+
 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_events_all_servers_use_same_channel(c, s, a):
     """Ensure that logs from all server types (scheduler, worker, nanny)
@@ -160,7 +251,7 @@ def user_event_handler(event):
 
     c.subscribe_topic("test-topic", user_event_handler)
 
-    while not s.event_subscriber["test-topic"]:
+    while not s._broker._topics["test-topic"].subscribers:
         await asyncio.sleep(0.01)
 
     async with Nanny(s.address) as n:
@@ -208,17 +299,18 @@ class C:
 @gen_cluster(client=True, config={"distributed.admin.low-level-log-length": 3})
 async def test_configurable_events_log_length(c, s, a, b):
     s.log_event("test", "dummy message 1")
-    assert len(s.events["test"]) == 1
+    assert len(s.get_events("test")) == 1
     s.log_event("test", "dummy message 2")
     s.log_event("test", "dummy message 3")
-    assert len(s.events["test"]) == 3
+    assert len(s.get_events("test")) == 3
+    assert s._broker._topics["test"].count == 3
 
     # adding a fourth message will drop the first one and length stays at 3
     s.log_event("test", "dummy message 4")
-    assert len(s.events["test"]) == 3
-    assert s.events["test"][0][1] == "dummy message 2"
-    assert s.events["test"][1][1] == "dummy message 3"
-    assert s.events["test"][2][1] == "dummy message 4"
+    assert len(s.get_events("test")) == 3
+    assert s._broker._topics["test"].count == 4
+    events = [event for _, event in s.get_events("test")]
+    assert events == ["dummy message 2", "dummy message 3", "dummy message 4"]
 
 
 @gen_cluster(client=True, nthreads=[])
@@ -282,3 +374,44 @@ class C:
             "worker": a.address,
         },
     ] == [msg[1] for msg in s.get_events("test-topic")]
+
+
+@gen_cluster(client=True)
+async def test_log_event_plugin(c, s, a, b):
+    class EventPlugin(SchedulerPlugin):
+        async def start(self, scheduler: Scheduler) -> None:
+            self.scheduler = scheduler
+            self.scheduler._recorded_events = list()  # type: ignore
+
+        def log_event(self, topic, msg):
+            self.scheduler._recorded_events.append((topic, msg))
+
+    await c.register_plugin(EventPlugin())
+
+    def f():
+        get_worker().log_event("foo", 123)
+
+    await c.submit(f)
+
+    assert ("foo", 123) in s._recorded_events
+
+
+@gen_cluster(client=True)
+async def test_log_event_plugin_multiple_topics(c, s, a, b):
+    class EventPlugin(SchedulerPlugin):
+        async def start(self, scheduler: Scheduler) -> None:
+            self.scheduler = scheduler
+            self.scheduler._recorded_events = list()  # type: ignore
+
+        def log_event(self, topic, msg):
+            self.scheduler._recorded_events.append((topic, msg))
+
+    await c.register_plugin(EventPlugin())
+
+    def f():
+        get_worker().log_event(["foo", "bar"], 123)
+
+    await c.submit(f)
+
+    assert ("foo", 123) in s._recorded_events
+    assert ("bar", 123) in s._recorded_events
diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py
index 1314ea8e277..52073c13d2b 100644
--- a/distributed/tests/test_nanny.py
+++ b/distributed/tests/test_nanny.py
@@ -554,7 +554,7 @@ async def test_nanny_closed_by_keyboard_interrupt(ucx_loop, protocol):
         ) as n:
             await n.process.stopped.wait()
             # Check that the scheduler has been notified about the closed worker
-            assert "remove-worker" in str(s.events)
+            assert "remove-worker" in str(s.get_events())
 
 
 class BrokenWorker(worker.Worker):
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 7c2cc391ad2..9aaea1288a4 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -865,39 +865,39 @@ async def test_remove_worker_by_name_from_scheduler(s, a, b):
 
 @gen_cluster(config={"distributed.scheduler.events-cleanup-delay": "500 ms"})
 async def test_clear_events_worker_removal(s, a, b):
-    assert a.address in s.events
+    assert a.address in s._broker._topics
     assert a.address in s.workers
-    assert b.address in s.events
+    assert b.address in s._broker._topics
     assert b.address in s.workers
 
     await s.remove_worker(address=a.address, stimulus_id="test")
     # Shortly after removal, the events should still be there
-    assert a.address in s.events
+    assert s.get_events(a.address)
     assert a.address not in s.workers
     s.validate_state()
 
     start = time()
-    while a.address in s.events:
+    while s.get_events(a.address):
         await asyncio.sleep(0.01)
         assert time() < start + 2
-    assert b.address in s.events
+    assert b.address in s._broker._topics
 
 
 @gen_cluster(
     config={"distributed.scheduler.events-cleanup-delay": "10 ms"}, client=True
 )
 async def test_clear_events_client_removal(c, s, a, b):
-    assert c.id in s.events
+    assert s.get_events(c.id)
     s.remove_client(c.id)
 
-    assert c.id in s.events
+    assert s.get_events(c.id)
     assert c.id not in s.clients
     assert c not in s.clients
 
     s.remove_client(c.id)
     # If it doesn't reconnect after a given time, the events log should be cleared
     start = time()
-    while c.id in s.events:
+    while s.get_events(c.id):
         await asyncio.sleep(0.01)
         assert time() < start + 2
 
@@ -2108,14 +2108,16 @@ def g(_, ev1, ev2):
     await ev2.set()
 
 
 @pytest.mark.slow
 @gen_cluster(
     client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False}
 )
 async def test_log_tasks_during_restart(c, s, a, b):
     future = c.submit(sys.exit, 0)
     await wait(future)
-    assert "exit" in str(s.events)
+    assert "exit" in str(
+        {name: topic.events for name, topic in s._broker._topics.items()}
+    )
 
 
 @gen_cluster(client=True)
@@ -4436,7 +4438,7 @@ def block(x, event):
     await event.set()
     await c.gather(futs)
 
-    assert "TaskState" not in str(s.events)
+    assert "TaskState" not in str(s.get_events())
 
 
 @gen_cluster(nthreads=[("", 1)])
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 539923dd1cc..3976857c9e7 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1178,7 +1178,7 @@ async def test_steal_worker_dies_same_ip(c, s, w0, w1):
     wsB = s.workers[w1.address]
 
     steal.move_task_request(victim_ts, wsA, wsB)
-    len_before = len(s.events["stealing"])
+    len_before = len(s.get_events("stealing"))
     with freeze_batched_send(w0.batched_stream):
         while not any(
             isinstance(event, StealRequestEvent) for event in w0.state.stimulus_log
@@ -1208,7 +1208,7 @@ async def test_steal_worker_dies_same_ip(c, s, w0, w1):
         assert hash(wsB2) != hash(wsB)
 
     # Wait for the steal response to arrive
-    while len_before == len(s.events["stealing"]):
+    while len_before == len(s.get_events("stealing")):
         await asyncio.sleep(0.1)
 
     assert victim_ts.processing_on != wsB
@@ -1875,5 +1875,5 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
     results = [dask.delayed(lambda *args: None)(root, i) for i in range(1000)]
     futs = c.compute(results)
     await c.gather(futs)
-    events = s.events["stealing"]
+    events = s.get_events("stealing")
     assert len(events) == 0
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index 7530b82ce48..c750e3faa5c 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -712,7 +712,7 @@ async def test_log_invalid_transitions(c, s, a):
     with pytest.raises(InvalidTransition):
         a.handle_stimulus(ev)
 
-    while not s.events["invalid-worker-transition"]:
+    while not s.get_events("invalid-worker-transition"):
         await asyncio.sleep(0.01)
 
     with pytest.raises(Exception) as info:
@@ -737,7 +737,7 @@ async def test_log_invalid_worker_task_state(c, s, a):
     with pytest.raises(InvalidTaskState):
         a.validate_state()
 
-    while not s.events["invalid-worker-task-state"]:
+    while not s.get_events("invalid-worker-task-state"):
         await asyncio.sleep(0.01)
 
     with pytest.raises(Exception) as info:
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 3ea579487c5..117ad004062 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -2906,7 +2906,7 @@ async def test_worker_status_sync(s, a):
     while ws.status != Status.closed:
         await asyncio.sleep(0.01)
 
-    events = [ev for _, ev in s.events[ws.address] if ev["action"] != "heartbeat"]
+    events = [ev for _, ev in s.get_events(ws.address) if ev["action"] != "heartbeat"]
     for ev in events:
         if "stimulus_id" in ev:  # Strip timestamp
             ev["stimulus_id"] = ev["stimulus_id"].rsplit("-", 1)[0]
@@ -2963,7 +2963,7 @@ async def test_log_remove_worker(c, s, a, b):
     # Scattered task
     z = await c.scatter({"z": 3}, workers=a.address)
 
-    s.events.clear()
+    s._broker.truncate()
 
     with captured_logger("distributed.scheduler", level=logging.INFO) as log:
         # Successful graceful shutdown
@@ -2999,7 +2999,7 @@ async def test_log_remove_worker(c, s, a, b):
         "Lost all workers",
     ]
 
-    events = {topic: [ev for _, ev in evs] for topic, evs in s.events.items()}
+    events = {topic: [ev for _, ev in evs] for topic, evs in s.get_events().items()}
     for evs in events.values():
         for ev in evs:
             if ev["action"] == "retire-workers":
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 78943226bff..1fd59b5525a 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -806,24 +806,25 @@ async def start_cluster(
 
 
 def check_invalid_worker_transitions(s: Scheduler) -> None:
-    if not s.events.get("invalid-worker-transition"):
+    if not s.get_events("invalid-worker-transition"):
         return
 
-    for _, msg in s.events["invalid-worker-transition"]:
+    for _, msg in s.get_events("invalid-worker-transition"):
         worker = msg.pop("worker")
         print("Worker:", worker)
         print(InvalidTransition(**msg))
 
     raise ValueError(
-        "Invalid worker transitions found", len(s.events["invalid-worker-transition"])
+        "Invalid worker transitions found",
+        len(s.get_events("invalid-worker-transition")),
     )
 
 
 def check_invalid_task_states(s: Scheduler) -> None:
-    if not s.events.get("invalid-worker-task-state"):
+    if not s.get_events("invalid-worker-task-state"):
         return
 
-    for _, msg in s.events["invalid-worker-task-state"]:
+    for _, msg in s.get_events("invalid-worker-task-state"):
         print("Worker:", msg["worker"])
         print("State:", msg["state"])
         for line in msg["story"]:
@@ -833,10 +834,10 @@ def check_invalid_task_states(s: Scheduler) -> None:
 
 
 def check_worker_fail_hard(s: Scheduler) -> None:
-    if not s.events.get("worker-fail-hard"):
+    if not s.get_events("worker-fail-hard"):
         return
 
-    for _, msg in s.events["worker-fail-hard"]:
+    for _, msg in s.get_events("worker-fail-hard"):
         msg = msg.copy()
         worker = msg.pop("worker")
         msg["exception"] = deserialize(msg["exception"].header, msg["exception"].frames)