diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e5505f7341f..a7dd6f68617 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5144,10 +5144,8 @@ async def restart(self, client=None, timeout=30): self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() - while ( - time() < start + 10 - and len(self.workers) < n_workers - and initial_workers & self._get_worker_ids() + while time() < start + 10 and ( + len(self.workers) < n_workers or initial_workers & self._get_worker_ids() ): await asyncio.sleep(0.01) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 09b3cc6a12c..6eac6eee5b2 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3494,11 +3494,13 @@ async def test_Client_clears_references_after_restart(c, s, a, b): @pytest.mark.slow -@gen_cluster(Worker=Nanny, client=True, nthreads=[("", 1)]) -async def test_restart_waits_for_new_workers(c, s, a): +@gen_cluster(Worker=Nanny, client=True, nthreads=[("", 1)] * 5) +async def test_restart_waits_for_new_workers(c, s, *workers): initial_workers = set(s.workers) await c.restart() - assert set(s.workers) != initial_workers + assert len(s.workers) == len(initial_workers) + for w in workers: + assert w.address not in s.workers @gen_cluster(Worker=Nanny, client=True)