From 6b91ec60469d7e93d4894ea99b5db68f5a6824f8 Mon Sep 17 00:00:00 2001
From: Florian Jetter
Date: Fri, 20 May 2022 10:04:50 +0200
Subject: [PATCH] Remove wrong assert in handle compute (#6370)

---
 distributed/tests/test_cancelled_state.py | 56 +++++++++++++++++++++++
 distributed/worker.py                     | 12 ++---
 2 files changed, 60 insertions(+), 8 deletions(-)

diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index a29239db469..5878120370e 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -396,3 +396,59 @@ def block_execution(event, lock):
         await lock_executing.release()
 
     assert await fut2 == 2
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 2)
+async def test_cancelled_resumed_after_flight_with_dependencies(c, s, w2, w3):
+    # See https://github.com/dask/distributed/pull/6327#discussion_r872231090
+    block_get_data_1 = asyncio.Lock()
+    enter_get_data_1 = asyncio.Event()
+    await block_get_data_1.acquire()
+
+    class BlockGetDataWorker(Worker):
+        def __init__(self, *args, get_data_event, get_data_lock, **kwargs):
+            self._get_data_event = get_data_event
+            self._get_data_lock = get_data_lock
+            super().__init__(*args, **kwargs)
+
+        async def get_data(self, comm, *args, **kwargs):
+            self._get_data_event.set()
+            async with self._get_data_lock:
+                return await super().get_data(comm, *args, **kwargs)
+
+    async with await BlockGetDataWorker(
+        s.address,
+        get_data_event=enter_get_data_1,
+        get_data_lock=block_get_data_1,
+        name="w1",
+    ) as w1:
+
+        f1 = c.submit(inc, 1, key="f1", workers=[w1.address])
+        f2 = c.submit(inc, 2, key="f2", workers=[w1.address])
+        f3 = c.submit(sum, [f1, f2], key="f3", workers=[w1.address])
+
+        await wait(f3)
+        f4 = c.submit(inc, f3, key="f4", workers=[w2.address])
+
+        await enter_get_data_1.wait()
+        s.set_restrictions(
+            {
+                f1.key: {w3.address},
+                f2.key: {w3.address},
+                f3.key: {w2.address},
+            }
+        )
+        await s.remove_worker(w1.address, "stim-id")
+
+        await wait_for_state(f3.key, "resumed", w2)
+        assert_story(
+            w2.log,
+            [
+                (f3.key, "flight", "released", "cancelled", {}),
+                # ...
+                (f3.key, "cancelled", "waiting", "resumed", {}),
+            ],
+        )
+        # w1 closed
+
+        assert await f4 == 6
diff --git a/distributed/worker.py b/distributed/worker.py
index 57a872455fd..84f25221a63 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -827,6 +827,7 @@ def __init__(
 
         # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
         pc = PeriodicCallback(self.find_missing, 1000)  # type: ignore
+        self._find_missing_running = False
         self.periodic_callbacks["find-missing"] = pc
 
         self._address = contact_address
@@ -1983,13 +1984,6 @@ def handle_compute_task(
         self.transitions(recommendations, stimulus_id=stimulus_id)
         self._handle_instructions(instructions)
 
-        if self.validate:
-            # All previously unknown tasks that were created above by
-            # ensure_tasks_exists() have been transitioned to fetch or flight
-            assert all(
-                ts2.state != "released" for ts2 in (ts, *ts.dependencies)
-            ), self.story(ts, *ts.dependencies)
-
     ########################
     # Worker State Machine #
     ########################
@@ -3432,9 +3426,10 @@ def _readd_busy_worker(self, worker: str) -> None:
 
     @log_errors
     async def find_missing(self) -> None:
-        if not self._missing_dep_flight:
+        if self._find_missing_running or not self._missing_dep_flight:
             return
         try:
+            self._find_missing_running = True
             if self.validate:
                 for ts in self._missing_dep_flight:
                     assert not ts.who_has
@@ -3452,6 +3447,7 @@
             self.transitions(recommendations, stimulus_id=stimulus_id)
 
         finally:
+            self._find_missing_running = False
             # This is quite arbitrary but the heartbeat has scaling implemented
             self.periodic_callbacks[
                 "find-missing"