From bd61588f690603c63b792085efe3bb6059adfaec Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 27 Jun 2022 19:41:42 +0200 Subject: [PATCH] Ensure client.restart waits for workers to leave --- distributed/scheduler.py | 10 ++++++++-- distributed/tests/test_client.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9eadaf51288..3dadb3039d7 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5083,11 +5083,15 @@ def clear_task_state(self): for collection in self._task_state_collections: collection.clear() + def _get_worker_ids(self) -> set[str]: + return set({ws.server_id for ws in self.workers.values()}) + @log_errors async def restart(self, client=None, timeout=30): """Restart all workers. Reset local state.""" stimulus_id = f"restart-{time()}" - n_workers = len(self.workers) + initial_workers = self._get_worker_ids() + n_workers = len(initial_workers) logger.info("Send lost future signal to clients") for cs in self.clients.values(): @@ -5161,7 +5165,9 @@ async def restart(self, client=None, timeout=30): self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() - while time() < start + 10 and len(self.workers) < n_workers: + while time() < start + 10 and ( + len(self.workers) < n_workers or initial_workers & self._get_worker_ids() + ): await asyncio.sleep(0.01) self.report({"op": "restart"}) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 42f019bb983..6907a703b55 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3493,6 +3493,16 @@ async def test_Client_clears_references_after_restart(c, s, a, b): assert key not in c.refcount +@pytest.mark.slow +@gen_cluster(Worker=Nanny, client=True, nthreads=[("", 1)] * 5) +async def test_restart_waits_for_new_workers(c, s, *workers): + initial_workers = set(s.workers) + await c.restart() + assert len(s.workers) == len(initial_workers) + for w in workers: + assert w.address not in s.workers + + @gen_cluster(Worker=Nanny, client=True) async def test_restart_timeout_is_logged(c, s, a, b): with captured_logger(logging.getLogger("distributed.client")) as logger: