Skip to content

Commit

Permalink
Return results in restart_workers
Browse files Browse the repository at this point in the history
  • Loading branch information
milesgranger committed Mar 1, 2023
1 parent 40456b4 commit 5359651
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
14 changes: 11 additions & 3 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3500,8 +3500,8 @@ def restart(self, timeout=no_default, wait_for_workers=True):

async def _restart_workers(
self, workers: list[str], timeout: int | float | None = None
):
results = await self.scheduler.broadcast(
) -> dict[str, str]:
results: dict[str, str] = await self.scheduler.broadcast(
msg={"op": "restart", "timeout": timeout}, workers=workers, nanny=True
)
timeout_workers = {
Expand All @@ -3511,8 +3511,11 @@ async def _restart_workers(
raise TimeoutError(
f"The following workers failed to restart with {timeout} seconds: {list(timeout_workers.keys())}"
)
return results

def restart_workers(self, workers: list[str], timeout: int | float | None = None):
def restart_workers(
self, workers: list[str], timeout: int | float | None = None
) -> dict[str, str]:
"""Restart a specified set of workers
.. note::
Expand All @@ -3528,6 +3531,11 @@ def restart_workers(self, workers: list[str], timeout: int | float | None = None
timeout : int | float | None
Number of seconds to wait
Returns
-------
dict[str, str]
Mapping of worker and restart status.
Notes
-----
This method differs from :meth:`Client.restart` in that this method
Expand Down
3 changes: 2 additions & 1 deletion distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4803,7 +4803,8 @@ async def test_restart_workers(c, s, a, b):
assert await c.compute(x.sum()) == size

# Restart a single worker
await c.restart_workers(workers=[a.worker_address])
results = await c.restart_workers(workers=[a.worker_address])
assert results[a.worker_address] == "OK"
assert set(s.workers) == {a.worker_address, b.worker_address}

# Make sure worker start times are as expected
Expand Down

0 comments on commit 5359651

Please sign in to comment.