Skip to content

Commit

Permalink
Do not allow closing workers to be awaited again (#5910)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored May 5, 2022
1 parent 7bd6442 commit 2286896
Show file tree
Hide file tree
Showing 18 changed files with 270 additions and 121 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ repos:
- types-psutil
- types-setuptools
# Typed libraries
- numpy
- dask
- numpy
- pytest
- tornado
- zict
- pyarrow
5 changes: 3 additions & 2 deletions distributed/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ async def run_actor_function_on_worker():
if self._future and not self._future.done():
await self._future
return await run_actor_function_on_worker()
else: # pragma: no cover
raise OSError("Unable to contact Actor's worker")
else:
exc = OSError("Unable to contact Actor's worker")
return _Error(exc)
if result["status"] == "OK":
return _OK(result["result"])
return _Error(result["exception"])
Expand Down
14 changes: 12 additions & 2 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,11 +1315,21 @@ async def _wait_for_workers(self, n_workers=0, timeout=None):
deadline = time() + parse_timedelta(timeout)
else:
deadline = None
while n_workers and len(info["workers"]) < n_workers:

def running_workers(info):
return len(
[
ws
for ws in info["workers"].values()
if ws["status"] == Status.running.name
]
)

while n_workers and running_workers(info) < n_workers:
if deadline and time() > deadline:
raise TimeoutError(
"Only %d/%d workers arrived after %s"
% (len(info["workers"]), n_workers, timeout)
% (running_workers(info), n_workers, timeout)
)
await asyncio.sleep(0.1)
info = await self.scheduler.identity()
Expand Down
98 changes: 63 additions & 35 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from distributed.metrics import time
from distributed.system_monitor import SystemMonitor
from distributed.utils import (
TimeoutError,
get_traceback,
has_keyword,
is_coroutine_function,
Expand Down Expand Up @@ -71,11 +70,6 @@ class Status(Enum):


Status.lookup = {s.name: s for s in Status} # type: ignore
Status.ANY_RUNNING = { # type: ignore
Status.running,
Status.paused,
Status.closing_gracefully,
}


class RPCClosed(IOError):
Expand Down Expand Up @@ -168,6 +162,7 @@ def __init__(
timeout=None,
io_loop=None,
):
self._status = Status.init
self.handlers = {
"identity": self.identity,
"echo": self.echo,
Expand Down Expand Up @@ -257,7 +252,8 @@ def set_thread_ident():

self.io_loop.add_callback(set_thread_ident)
self._startup_lock = asyncio.Lock()
self.status = Status.undefined
self.__startup_exc: Exception | None = None
self.__started = asyncio.Event()

self.rpc = ConnectionPool(
limit=connection_limit,
Expand Down Expand Up @@ -289,31 +285,48 @@ async def finished(self):
await self._event_finished.wait()

def __await__(self):
async def _():
timeout = getattr(self, "death_timeout", 0)
async with self._startup_lock:
if self.status in Status.ANY_RUNNING:
return self
if timeout:
try:
await asyncio.wait_for(self.start(), timeout=timeout)
self.status = Status.running
except Exception:
await self.close(timeout=1)
raise TimeoutError(
"{} failed to start in {} seconds".format(
type(self).__name__, timeout
)
)
else:
await self.start()
self.status = Status.running
return self
return self.start().__await__()

return _().__await__()
async def start_unsafe(self):
"""Attempt to start the server. This is not idempotent and not protected against concurrent startup attempts.
async def start(self):
This is intended to be overridden or called by subclasses. For a safe
startup, please use ``Server.start`` instead.
If ``death_timeout`` is configured, we will require this coroutine to
finish before this timeout is reached. If the timeout is reached we will
close the instance and raise an ``asyncio.TimeoutError``.
"""
await self.rpc.start()
return self

async def start(self):
async with self._startup_lock:
if self.status == Status.failed:
assert self.__startup_exc is not None
raise self.__startup_exc
elif self.status != Status.init:
return self
timeout = getattr(self, "death_timeout", None)

async def _close_on_failure(exc: Exception):
await self.close()
self.status = Status.failed
self.__startup_exc = exc

try:
await asyncio.wait_for(self.start_unsafe(), timeout=timeout)
except asyncio.TimeoutError as exc:
await _close_on_failure(exc)
raise asyncio.TimeoutError(
f"{type(self).__name__} start timed out after {timeout}s."
) from exc
except Exception as exc:
await _close_on_failure(exc)
raise RuntimeError(f"{type(self).__name__} failed to start.") from exc
self.status = Status.running
self.__started.set()
return self

async def __aenter__(self):
await self
Expand Down Expand Up @@ -382,16 +395,28 @@ def _cycle_ticks(self):
self._tick_interval_observed = (time() - last) / (count or 1)

@property
def address(self):
def address(self) -> str:
"""
The address this Server can be contacted on.
If the server is not up yet, this raises a ValueError.
"""
if not self._address:
if self.listener is None:
raise ValueError("cannot get address of non-running Server")
self._address = self.listener.contact_address
return self._address

@property
def address_safe(self) -> str:
"""
The address this Server can be contacted on.
If the server is not up yet, this returns ``"not-running"``.
"""
try:
return self.address
except ValueError:
return "not-running"

@property
def listen_address(self):
"""
Expand Down Expand Up @@ -480,6 +505,7 @@ async def handle_comm(self, comm):

logger.debug("Connection from %r to %s", address, type(self).__name__)
self._comms[comm] = op

await self
try:
while True:
Expand Down Expand Up @@ -650,11 +676,13 @@ async def handle_stream(self, comm, extra=None, every_cycle=()):
def close(self):
for pc in self.periodic_callbacks.values():
pc.stop()
self.__stopped = True
for listener in self.listeners:
future = listener.stop()
if inspect.isawaitable(future):
yield future

if not self.__stopped:
self.__stopped = True
for listener in self.listeners:
future = listener.stop()
if inspect.isawaitable(future):
yield future
for i in range(20):
# If there are still handlers running at this point, give them a
# second to finish gracefully themselves, otherwise...
Expand Down
7 changes: 2 additions & 5 deletions distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,7 @@ async def _correct_state_internal(self):
for w in to_close
if w in self.workers
]
await asyncio.wait(tasks)
for task in tasks: # for tornado gen.coroutine support
with suppress(RuntimeError):
await task
await asyncio.gather(*tasks)
for name in to_close:
if name in self.workers:
del self.workers[name]
Expand Down Expand Up @@ -417,7 +414,7 @@ async def _close(self):

await self.scheduler.close()
for w in self._created:
assert w.status == Status.closed, w.status
assert w.status in {Status.closed, Status.failed}, w.status

if hasattr(self, "_old_logging_level"):
silence_logging(self._old_logging_level)
Expand Down
21 changes: 11 additions & 10 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ class Nanny(ServerNode):

_instances: ClassVar[weakref.WeakSet[Nanny]] = weakref.WeakSet()
process = None
status = Status.undefined
memory_manager: NannyMemoryManager

# Inputs to parse_ports()
Expand Down Expand Up @@ -269,7 +268,6 @@ def __init__(

self._listen_address = listen_address
Nanny._instances.add(self)
self.status = Status.init

# Deprecated attributes; use Nanny.memory_manager.<name> instead
memory_limit = DeprecatedMemoryManagerAttribute()
Expand Down Expand Up @@ -309,10 +307,10 @@ def local_dir(self):
warnings.warn("The local_dir attribute has moved to local_directory")
return self.local_directory

async def start(self):
async def start_unsafe(self):
"""Start nanny, start local process, start watching"""

await super().start()
await super().start_unsafe()

ports = parse_ports(self._start_port)
for port in ports:
Expand All @@ -337,7 +335,7 @@ async def start(self):
break
else:
raise ValueError(
f"Could not start Nanny on host {self._start_host}"
f"Could not start Nanny on host {self._start_host} "
f"with port {self._start_port}"
)

Expand All @@ -352,11 +350,12 @@ async def start(self):

logger.info(" Start Nanny at: %r", self.address)
response = await self.instantiate()
if response == Status.running:
assert self.worker_address
self.status = Status.running
else:

if response != Status.running:
await self.close()
return

assert self.worker_address

self.start_periodic_callbacks()

Expand Down Expand Up @@ -571,7 +570,9 @@ async def close(self, comm=None, timeout=5, report=None):

self.status = Status.closing
logger.info(
f"Closing Nanny at {self.address!r}. Report closure to scheduler: {report}"
"Closing Nanny at %r. Report closure to scheduler: %s",
self.address_safe,
report,
)

for preload in self.preloads:
Expand Down
9 changes: 4 additions & 5 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,7 @@ def identity(self) -> dict[str, Any]:
"last_seen": self.last_seen,
"services": self.services,
"metrics": self.metrics,
"status": self.status.name,
"nanny": self.nanny,
**self.extra,
}
Expand Down Expand Up @@ -3235,15 +3236,14 @@ def __init__(
setproctitle("dask-scheduler [not started]")
Scheduler._instances.add(self)
self.rpc.allow_offload = False
self.status = Status.undefined

##################
# Administration #
##################

def __repr__(self):
return (
f"<Scheduler {self.address!r}, "
f"<Scheduler {self.address_safe!r}, "
f"workers: {len(self.workers)}, "
f"cores: {self.total_nthreads}, "
f"tasks: {len(self.tasks)}>"
Expand Down Expand Up @@ -3376,10 +3376,9 @@ def get_worker_service_addr(
else:
return ws.host, port

async def start(self):
async def start_unsafe(self):
"""Clear out old state and restart all running coroutines"""
await super().start()
assert self.status != Status.running
await super().start_unsafe()

enable_gc_diagnosis()

Expand Down
1 change: 0 additions & 1 deletion distributed/tests/test_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ async def test_failed_worker(c, s, a, b):

assert "actor" in str(info.value).lower()
assert "worker" in str(info.value).lower()
assert "lost" in str(info.value).lower()


@gen_cluster(client=True)
Expand Down
Loading

0 comments on commit 2286896

Please sign in to comment.