diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 8bfca906148..2f431da3ae8 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -112,6 +112,9 @@ def add_client(self, scheduler: Scheduler, client: str) -> None: def remove_client(self, scheduler: Scheduler, client: str) -> None: """Run when a client disconnects""" + def log_event(self, name, msg) -> None: + """Run when an event is logged""" + class WorkerPlugin: """Interface to extend the Worker diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index ea18aba764c..59511ba456a 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -1,6 +1,6 @@ import pytest -from distributed import Scheduler, SchedulerPlugin, Worker +from distributed import Scheduler, SchedulerPlugin, Worker, get_worker from distributed.utils_test import gen_cluster, gen_test, inc @@ -178,3 +178,23 @@ def start(self, scheduler): assert "distributed.scheduler.pickle" in msg assert n_plugins == len(s.plugins) + + +@gen_cluster(client=True) +async def test_log_event_plugin(c, s, a, b): + class EventPlugin(SchedulerPlugin): + async def start(self, scheduler: Scheduler) -> None: + self.scheduler = scheduler + self.scheduler._recorded_events = list() # type: ignore + + def log_event(self, name, msg): + self.scheduler._recorded_events.append((name, msg)) + + await c.register_scheduler_plugin(EventPlugin()) + + def f(): + get_worker().log_event("foo", 123) + + await c.submit(f) + + assert ("foo", 123) in s._recorded_events diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 805609a19ec..58922501b6d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6941,6 +6941,12 @@ def log_event(self, name, msg): self.event_counts[name] += 1 self._report_event(name, event) + for plugin in list(self.plugins.values()): + try: + plugin.log_event(name, msg) + except Exception: + logger.info("Plugin failed with exception", exc_info=True) + def _report_event(self, name, event): for client in self.event_subscriber[name]: self.report( diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 1abb00db554..5878120370e 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -2,6 +2,7 @@ import distributed from distributed import Event, Lock, Worker +from distributed.client import wait from distributed.utils_test import ( _LockedCommPool, assert_story, @@ -264,11 +265,12 @@ async def test_in_flight_lost_after_resumed(c, s, b): block_get_data = asyncio.Lock() in_get_data = asyncio.Event() + await block_get_data.acquire() lock_executing = Lock() def block_execution(lock): with lock: - return + return 1 class BlockedGetData(Worker): async def get_data(self, comm, *args, **kwargs): @@ -281,15 +283,12 @@ async def get_data(self, comm, *args, **kwargs): block_execution, lock_executing, workers=[a.address], - allow_other_workers=True, key="fut1", ) # Ensure fut1 is in memory but block any further execution afterwards to # ensure we control when the recomputation happens - await fut1 + await wait(fut1) await lock_executing.acquire() - in_get_data.clear() - await block_get_data.acquire() fut2 = c.submit(inc, fut1, workers=[b.address], key="fut2") # This ensures that B already fetches the task, i.e. after this the task @@ -298,6 +297,7 @@ async def get_data(self, comm, *args, **kwargs): assert fut1.key in b.tasks assert b.tasks[fut1.key].state == "flight" + s.set_restrictions({fut1.key: [a.address, b.address]}) # It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled # to be recomputed on B await s.remove_worker(a.address, "foo", close=False, safe=True) @@ -396,3 +396,59 @@ def block_execution(event, lock): await lock_executing.release() assert await fut2 == 2 + + +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_cancelled_resumed_after_flight_with_dependencies(c, s, w2, w3): + # See https://github.com/dask/distributed/pull/6327#discussion_r872231090 + block_get_data_1 = asyncio.Lock() + enter_get_data_1 = asyncio.Event() + await block_get_data_1.acquire() + + class BlockGetDataWorker(Worker): + def __init__(self, *args, get_data_event, get_data_lock, **kwargs): + self._get_data_event = get_data_event + self._get_data_lock = get_data_lock + super().__init__(*args, **kwargs) + + async def get_data(self, comm, *args, **kwargs): + self._get_data_event.set() + async with self._get_data_lock: + return await super().get_data(comm, *args, **kwargs) + + async with await BlockGetDataWorker( + s.address, + get_data_event=enter_get_data_1, + get_data_lock=block_get_data_1, + name="w1", + ) as w1: + + f1 = c.submit(inc, 1, key="f1", workers=[w1.address]) + f2 = c.submit(inc, 2, key="f2", workers=[w1.address]) + f3 = c.submit(sum, [f1, f2], key="f3", workers=[w1.address]) + + await wait(f3) + f4 = c.submit(inc, f3, key="f4", workers=[w2.address]) + + await enter_get_data_1.wait() + s.set_restrictions( + { + f1.key: {w3.address}, + f2.key: {w3.address}, + f3.key: {w2.address}, + } + ) + await s.remove_worker(w1.address, "stim-id") + + await wait_for_state(f3.key, "resumed", w2) + assert_story( + w2.log, + [ + (f3.key, "flight", "released", "cancelled", {}), + # ... + (f3.key, "cancelled", "waiting", "resumed", {}), + ], + ) + # w1 closed + + assert await f4 == 6 diff --git a/distributed/worker.py b/distributed/worker.py index 3a9efd296d4..ef752a8dee4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -828,6 +828,7 @@ def __init__( # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117 pc = PeriodicCallback(self.find_missing, 1000) # type: ignore + self._find_missing_running = False self.periodic_callbacks["find-missing"] = pc self._address = contact_address @@ -1982,13 +1983,6 @@ def handle_compute_task( self.transitions(recommendations, stimulus_id=stimulus_id) self._handle_instructions(instructions) - if self.validate: - # All previously unknown tasks that were created above by - # ensure_tasks_exists() have been transitioned to fetch or flight - assert all( - ts2.state != "released" for ts2 in (ts, *ts.dependencies) - ), self.story(ts, *ts.dependencies) - ######################## # Worker State Machine # ######################## @@ -3431,9 +3425,10 @@ def _readd_busy_worker(self, worker: str) -> None: @log_errors async def find_missing(self) -> None: - if not self._missing_dep_flight: + if self._find_missing_running or not self._missing_dep_flight: return try: + self._find_missing_running = True if self.validate: for ts in self._missing_dep_flight: assert not ts.who_has @@ -3451,6 +3446,7 @@ async def find_missing(self) -> None: self.transitions(recommendations, stimulus_id=stimulus_id) finally: + self._find_missing_running = False # This is quite arbitrary but the heartbeat has scaling implemented self.periodic_callbacks[ "find-missing"