Skip to content

Commit

Permalink
Ensure client.restart waits for workers to leave
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Jun 27, 2022
1 parent a8eb3b2 commit a905b6d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
12 changes: 10 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5062,11 +5062,15 @@ def clear_task_state(self):
for collection in self._task_state_collections:
collection.clear()

def _get_worker_ids(self) -> set[str]:
    """Return the ``server_id`` of every entry currently in ``self.workers``.

    The result is a fresh set, so callers can hold on to it as a snapshot
    and compare it against a later call (e.g. with ``&``) to see which
    ids are still present.
    """
    # A set comprehension already builds a new set; wrapping it in
    # ``set(...)`` would only make a redundant copy.
    return {ws.server_id for ws in self.workers.values()}

@log_errors
async def restart(self, client=None, timeout=30):
"""Restart all workers. Reset local state."""
stimulus_id = f"restart-{time()}"
n_workers = len(self.workers)
initial_workers = self._get_worker_ids()
n_workers = len(initial_workers)

logger.info("Send lost future signal to clients")
for cs in self.clients.values():
Expand Down Expand Up @@ -5140,7 +5144,11 @@ async def restart(self, client=None, timeout=30):

self.log_event([client, "all"], {"action": "restart", "client": client})
start = time()
while time() < start + 10 and len(self.workers) < n_workers:
while (
time() < start + 10
and len(self.workers) < n_workers
and initial_workers & self._get_worker_ids()
):
await asyncio.sleep(0.01)

self.report({"op": "restart"})
Expand Down
8 changes: 8 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3493,6 +3493,14 @@ async def test_Client_clears_references_after_restart(c, s, a, b):
assert key not in c.refcount


@pytest.mark.slow
@gen_cluster(Worker=Nanny, client=True, nthreads=[("", 1)])
async def test_restart_waits_for_new_workers(c, s, a):
    """After ``c.restart()`` the keys of ``s.workers`` must have changed.

    Snapshots the keys of ``s.workers`` before and after the restart and
    asserts that the two snapshots are not equal.
    """
    workers_before = set(s.workers)
    await c.restart()
    workers_after = set(s.workers)
    # NOTE(review): inequality also holds if ``s.workers`` is empty after
    # the restart -- confirm whether a stricter check is wanted here.
    assert workers_after != workers_before


@gen_cluster(Worker=Nanny, client=True)
async def test_restart_timeout_is_logged(c, s, a, b):
with captured_logger(logging.getLogger("distributed.client")) as logger:
Expand Down

0 comments on commit a905b6d

Please sign in to comment.