Skip to content

Commit

Permalink
Merge remote-tracking branch 'fjetter/ensure_nanny_restart_not_kill_w…
Browse files Browse the repository at this point in the history
…orker' into mindeps-testing
  • Loading branch information
charlesbluca committed Dec 5, 2022
2 parents 8e59a2c + 85562a0 commit 43d4017
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 104 deletions.
55 changes: 32 additions & 23 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def __init__(
self.digests = None
self._ongoing_background_tasks = AsyncTaskGroup()
self._event_finished = asyncio.Event()
self._event_started = asyncio.Event()

self.listeners = []
self.io_loop = self.loop = IOLoop.current()
Expand Down Expand Up @@ -489,6 +490,9 @@ async def finished(self):
"""Wait until the server has finished"""
await self._event_finished.wait()

async def started(self):
    """Wait until ``self._event_started`` has been set."""
    startup_event = self._event_started
    await startup_event.wait()

def __await__(self):
    """Make the instance awaitable by delegating to ``self.start()``."""
    startup = self.start()
    return startup.__await__()

Expand All @@ -507,30 +511,32 @@ async def start_unsafe(self):

@final
async def start(self):
async with self._startup_lock:
if self.status == Status.failed:
assert self.__startup_exc is not None
raise self.__startup_exc
elif self.status != Status.init:
return self
timeout = getattr(self, "death_timeout", None)

async def _close_on_failure(exc: Exception) -> None:
await self.close()
self.status = Status.failed
self.__startup_exc = exc
if self.status == Status.failed:
assert self.__startup_exc is not None
raise self.__startup_exc
elif self.status != Status.init:
return self

try:
async def _close_on_failure(exc: Exception) -> None:
self._event_started.set()
await self.close()
self.status = Status.failed
self.__startup_exc = exc

timeout = getattr(self, "death_timeout", None)
try:
async with self._startup_lock:
await asyncio.wait_for(self.start_unsafe(), timeout=timeout)
except asyncio.TimeoutError as exc:
await _close_on_failure(exc)
raise asyncio.TimeoutError(
f"{type(self).__name__} start timed out after {timeout}s."
) from exc
except Exception as exc:
await _close_on_failure(exc)
raise RuntimeError(f"{type(self).__name__} failed to start.") from exc
self.status = Status.running
self._event_started.set()
self.status = Status.running
except asyncio.TimeoutError as exc:
await _close_on_failure(exc)
raise asyncio.TimeoutError(
f"{type(self).__name__} start timed out after {timeout}s."
) from exc
except Exception as exc:
await _close_on_failure(exc)
raise RuntimeError(f"{type(self).__name__} failed to start.") from exc
return self

async def __aenter__(self):
Expand Down Expand Up @@ -741,7 +747,7 @@ async def _handle_comm(self, comm):
logger.debug("Connection from %r to %s", address, type(self).__name__)
self._comms[comm] = op

await self
await self.started()
try:
while not self.__stopped:
try:
Expand Down Expand Up @@ -940,6 +946,9 @@ async def close(self, timeout=None):
await asyncio.gather(*[comm.close() for comm in list(self._comms)])
finally:
self._event_finished.set()
logger.debug(
f"Closed {type(self).__name__} - {self.address_safe} - {self.id}"
)


def pingpong(comm):
Expand Down
139 changes: 66 additions & 73 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
import uuid
import warnings
import weakref
from collections import defaultdict
from collections.abc import Collection
from inspect import isawaitable
from queue import Empty
from time import sleep as sync_sleep
from typing import TYPE_CHECKING, Callable, ClassVar, Literal

from toolz import merge
Expand Down Expand Up @@ -119,6 +119,7 @@ class Nanny(ServerNode):
# Inputs to parse_ports()
_given_worker_port: int | str | Collection[int] | None
_start_port: int | str | Collection[int] | None
_process_callback_received: defaultdict[WorkerProcess, asyncio.Event]

def __init__( # type: ignore[no-untyped-def]
self,
Expand Down Expand Up @@ -223,6 +224,9 @@ def __init__( # type: ignore[no-untyped-def]
self.validate = validate
self.resources = resources

self._instantiate_lock = asyncio.Lock()
self._process_callback_received = defaultdict(asyncio.Event)

self.Worker = Worker if worker_class is None else worker_class

self.pre_spawn_env = _get_env_variables("distributed.nanny.pre-spawn-environ")
Expand Down Expand Up @@ -385,66 +389,50 @@ async def kill(self, timeout: float = 2, reason: str = "nanny-kill") -> None:
return

deadline = time() + timeout
await self.process.kill(reason=reason, timeout=0.8 * (deadline - time()))
proc = self.process
await proc.kill(reason=reason, timeout=0.8 * (deadline - time()))
assert proc.status in (Status.stopped, Status.failed), proc.status
assert proc.stopped.is_set()
await self._process_callback_received[proc].wait()
assert self.process is not proc

async def instantiate(self) -> Status:
"""Start a local worker process
Blocks until the process is up and the scheduler is properly informed
"""
if self.process is None:
worker_kwargs = dict(
scheduler_ip=self.scheduler_addr,
nthreads=self.nthreads,
local_directory=self._original_local_dir,
services=self.services,
nanny=self.address,
name=self.name,
memory_limit=self.memory_manager.memory_limit,
resources=self.resources,
validate=self.validate,
silence_logs=self.silence_logs,
death_timeout=self.death_timeout,
preload=self.preload,
preload_argv=self.preload_argv,
security=self.security,
contact_address=self.contact_address,
)
worker_kwargs.update(self.worker_kwargs)
self.process = WorkerProcess(
worker_kwargs=worker_kwargs,
silence_logs=self.silence_logs,
on_exit=self._on_worker_exit_sync,
worker=self.Worker,
env=self.env,
pre_spawn_env=self.pre_spawn_env,
config=self.config,
)

if self.death_timeout:
try:
result = await asyncio.wait_for(
self.process.start(), self.death_timeout
)
except asyncio.TimeoutError:
logger.error(
"Timed out connecting Nanny '%s' to scheduler '%s'",
self,
self.scheduler_addr,
# The lock is required since there are many possible race conditions due
# to the worker exit callback
async with self._instantiate_lock:
if self.process is None:
worker_kwargs = dict(
scheduler_ip=self.scheduler_addr,
nthreads=self.nthreads,
local_directory=self._original_local_dir,
services=self.services,
nanny=self.address,
name=self.name,
memory_limit=self.memory_manager.memory_limit,
resources=self.resources,
validate=self.validate,
silence_logs=self.silence_logs,
death_timeout=self.death_timeout,
preload=self.preload,
preload_argv=self.preload_argv,
security=self.security,
contact_address=self.contact_address,
)
await self.close(
timeout=self.death_timeout, reason="nanny-instantiate-timeout"
worker_kwargs.update(self.worker_kwargs)
self.process = WorkerProcess(
worker_kwargs=worker_kwargs,
silence_logs=self.silence_logs,
on_exit=self._on_worker_exit_sync,
worker=self.Worker,
env=self.env,
pre_spawn_env=self.pre_spawn_env,
config=self.config,
)
raise

else:
try:
result = await self.process.start()
except Exception:
logger.error("Failed to start process", exc_info=True)
await self.close(reason="nanny-instantiate-failed")
raise
return result
return await self.process.start()

@log_errors
async def plugin_add(self, plugin=None, name=None):
Expand Down Expand Up @@ -519,6 +507,9 @@ def _on_worker_exit_sync(self, exitcode):

@log_errors
async def _on_worker_exit(self, exitcode):
assert self.process
self._process_callback_received[self.process].set()
self.process = None
if self.status not in (
Status.init,
Status.closing,
Expand Down Expand Up @@ -550,6 +541,8 @@ async def _on_worker_exit(self, exitcode):
logger.error(
"Failed to restart worker after its process exited", exc_info=True
)
await self.close(reason="worker-failed-restart")
raise

@property
def pid(self):
Expand Down Expand Up @@ -578,11 +571,14 @@ async def close(
"""
if self.status == Status.closing:
await self.finished()
assert self.status == Status.closed
assert self.status in (Status.closed, Status.failed)

if self.status == Status.closed:
if self.status in (Status.closed, Status.failed):
return "OK"

# Make sure we're not colliding with the startup coro when setting the
# status to closing
await self.started()
self.status = Status.closing
logger.info("Closing Nanny at %r. Reason: %s", self.address_safe, reason)

Expand Down Expand Up @@ -726,6 +722,7 @@ async def start(self) -> Status:
self.running.set()

init_q.close()
init_q.join_thread()

return self.status

Expand Down Expand Up @@ -796,8 +793,12 @@ async def kill(
if self.status == Status.stopping:
await self.stopped.wait()
return
# If the process is not properly up it will not watch the closing queue
# and we may end up leaking this process.
# Therefore wait for it to be properly started before killing it.
if self.status == Status.starting:
await self.running.wait()
assert self.status in (
Status.starting,
Status.running,
Status.failed, # process failed to start, but hasn't been joined yet
), self.status
Expand All @@ -817,22 +818,20 @@ async def kill(
"reason": reason,
}
)
await asyncio.sleep(0) # otherwise we get broken pipe errors
queue.close()
queue.join_thread()
del queue

try:
try:
await process.join(wait_timeout)
return
except asyncio.TimeoutError:
pass

logger.warning(
f"Worker process still alive after {wait_timeout} seconds, killing"
)
await process.kill()
await process.join(max(0, deadline - time()))
logger.warning(
f"Worker process still alive after {wait_timeout} seconds, killing"
)
await process.kill()
await process.join(max(0, deadline - time()))
await self.stopped.wait()
except ValueError as e:
if "invalid operation on closed AsyncProcess" in str(e):
return
Expand Down Expand Up @@ -934,6 +933,7 @@ async def run() -> None:
}
)
init_result_q.close()
init_result_q.join_thread()
await worker.finished()
logger.info("Worker closed")
except Exception as e:
Expand All @@ -943,14 +943,7 @@ async def run() -> None:
logger.exception(f"Failed to {failure_type} worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for at least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
# cleanup wipes the queue before the exception can be
# properly handled. See also
# WorkerProcess._wait_until_connected (the 3 is for good
# measure)
sync_sleep(cls._init_msg_interval * 3)
init_result_q.join_thread()

with contextlib.ExitStack() as stack:

Expand Down
4 changes: 3 additions & 1 deletion distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3958,7 +3958,9 @@ async def log_errors(func):
await asyncio.gather(
*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]
)

# Make sure we're not colliding with the startup coro when setting the
# status to closing
await self.started()
self.status = Status.closing

logger.info("Scheduler closing...")
Expand Down
Loading

0 comments on commit 43d4017

Please sign in to comment.