Skip to content

Commit

Permalink
remove deprecated code calls to IOLoop.make_current() (#7240)
Browse files Browse the repository at this point in the history
Co-authored-by: Hendrik Makait <hendrik.makait@gmail.com>
  • Loading branch information
graingert and hendrikmakait authored Nov 9, 2022
1 parent 945a847 commit 88515db
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 136 deletions.
226 changes: 111 additions & 115 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import asyncio
import contextlib
import errno
import functools
import logging
import multiprocessing
import os
import shutil
import tempfile
Expand All @@ -11,14 +14,12 @@
import warnings
import weakref
from collections.abc import Collection
from contextlib import suppress
from inspect import isawaitable
from queue import Empty
from time import sleep as sync_sleep
from typing import TYPE_CHECKING, ClassVar, Literal
from typing import TYPE_CHECKING, Callable, ClassVar, Literal

from toolz import merge
from tornado import gen
from tornado.ioloop import IOLoop

import dask
Expand All @@ -45,7 +46,6 @@
from distributed.protocol import pickle
from distributed.security import Security
from distributed.utils import (
TimeoutError,
get_ip,
get_mp_context,
json_load_robust,
Expand Down Expand Up @@ -303,14 +303,15 @@ async def _unregister(self, timeout=10):
if worker_address is None:
return

allowed_errors = (TimeoutError, CommClosedError, EnvironmentError, RPCClosed)
with suppress(allowed_errors):
try:
await asyncio.wait_for(
self.scheduler.unregister(
address=self.worker_address, stimulus_id=f"nanny-close-{time()}"
),
timeout,
)
except (asyncio.TimeoutError, CommClosedError, OSError, RPCClosed):
pass

@property
def worker_address(self):
Expand Down Expand Up @@ -425,7 +426,7 @@ async def instantiate(self) -> Status:
result = await asyncio.wait_for(
self.process.start(), self.death_timeout
)
except TimeoutError:
except asyncio.TimeoutError:
logger.error(
"Timed out connecting Nanny '%s' to scheduler '%s'",
self,
Expand Down Expand Up @@ -496,7 +497,7 @@ async def _():

try:
await asyncio.wait_for(_(), timeout)
except TimeoutError:
except asyncio.TimeoutError:
logger.error(
f"Restart timed out after {timeout}s; returning before finished"
)
Expand Down Expand Up @@ -679,18 +680,18 @@ async def start(self) -> Status:
uid = uuid.uuid4().hex

self.process = AsyncProcess(
target=self._run,
name="Dask Worker process (from Nanny)",
kwargs=dict(
worker_kwargs=self.worker_kwargs,
target=functools.partial(
self._run,
silence_logs=self.silence_logs,
init_result_q=self.init_result_q,
child_stop_q=self.child_stop_q,
uid=uid,
Worker=self.Worker,
worker_factory=functools.partial(self.Worker, **self.worker_kwargs),
env=self.env,
config=self.config,
),
name="Dask Worker process (from Nanny)",
kwargs=dict(),
)
self.process.daemon = dask.config.get("distributed.worker.daemon", default=True)
self.process.set_exit_callback(self._on_exit)
Expand Down Expand Up @@ -860,86 +861,66 @@ async def _wait_until_connected(self, uid):
@classmethod
def _run(
cls,
worker_kwargs,
silence_logs,
init_result_q,
child_stop_q,
uid,
env,
config,
Worker,
): # pragma: no cover
try:
os.environ.update(env)
dask.config.refresh()
dask.config.set(config)

from dask.multiprocessing import default_initializer

default_initializer()

if silence_logs:
logger.setLevel(silence_logs)

IOLoop.clear_instance()
loop = IOLoop()
loop.make_current()
worker = Worker(**worker_kwargs)

async def do_stop(
timeout=5, executor_wait=True, reason="workerprocess-stop"
):
try:
await worker.close(
nanny=False,
executor_wait=executor_wait,
timeout=timeout,
reason=reason,
)
finally:
loop.stop()

def watch_stop_q():
"""
Wait for an incoming stop message and then stop the
worker cleanly.
"""
try:
msg = child_stop_q.get()
except (TypeError, OSError, EOFError):
logger.error("Worker process died unexpectedly")
msg = {"op": "stop"}
finally:
child_stop_q.close()
assert msg["op"] == "stop", msg
del msg["op"]
loop.add_callback(do_stop, **msg)

thread = threading.Thread(
target=watch_stop_q, name="Nanny stop queue watch"
silence_logs: bool,
init_result_q: multiprocessing.Queue,
child_stop_q: multiprocessing.Queue,
uid: str,
env: dict,
config: dict,
worker_factory: Callable[[], Worker],
) -> None: # pragma: no cover
async def do_stop(
*,
worker: Worker,
timeout: float = 5,
executor_wait: bool = True,
reason: str = "workerprocess-stop",
) -> None:
await worker.close(
nanny=False,
executor_wait=executor_wait,
timeout=timeout,
reason=reason,
)
thread.daemon = True
thread.start()

async def run():
"""
Try to start worker and inform parent of outcome.
"""
try:
await worker
except Exception as e:
logger.exception("Failed to start worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for a least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
# cleanup wipes the queue before the exception can be
# properly handled. See also
# WorkerProcess._wait_until_connected (the 2 is for good
# measure)
sync_sleep(cls._init_msg_interval * 2)
else:
def watch_stop_q(loop: IOLoop, worker: Worker) -> None:
"""
Wait for an incoming stop message and then stop the
worker cleanly.
"""
try:
msg = child_stop_q.get()
except (TypeError, OSError, EOFError):
logger.error("Worker process died unexpectedly")
msg = {"op": "stop"}
finally:
child_stop_q.close()
assert msg["op"] == "stop", msg
del msg["op"]
loop.add_callback(do_stop, worker=worker, **msg)

async def run() -> None:
"""
Try to start worker and inform parent of outcome.
"""
failure_type: str | None = "initialize"
try:
worker = worker_factory()
failure_type = "start"
thread = threading.Thread(
target=functools.partial(
watch_stop_q,
worker=worker,
loop=IOLoop.current(),
),
name="Nanny stop queue watch",
daemon=True,
)
thread.start()
stack.callback(thread.join, timeout=2)
async with worker:
failure_type = None

try:
assert worker.address
except ValueError:
Expand All @@ -955,34 +936,49 @@ async def run():
init_result_q.close()
await worker.finished()
logger.info("Worker closed")
except Exception as e:
if failure_type is None:
raise

except Exception as e:
logger.exception("Failed to initialize Worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for a least one
# interval for the outside to pick up this message. Otherwise we
# arrive in a race condition where the process cleanup wipes the
# queue before the exception can be properly handled. See also
# WorkerProcess._wait_until_connected (the 2 is for good measure)
sync_sleep(cls._init_msg_interval * 2)
else:
try:
loop.run_sync(run)
except (TimeoutError, gen.TimeoutError):
# Loop was stopped before wait_until_closed() returned, ignore
pass
except KeyboardInterrupt:
# At this point the loop is not running thus we have to run
# do_stop() explicitly.
loop.run_sync(do_stop)
finally:
with suppress(ValueError):
logger.exception(f"Failed to {failure_type} worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for a least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
# cleanup wipes the queue before the exception can be
# properly handled. See also
# WorkerProcess._wait_until_connected (the 3 is for good
# measure)
sync_sleep(cls._init_msg_interval * 3)

with contextlib.ExitStack() as stack:

@stack.callback
def close_stop_q() -> None:
try:
child_stop_q.put({"op": "stop"}) # usually redundant
with suppress(ValueError):
except ValueError:
pass

try:
child_stop_q.close() # usually redundant
except ValueError:
pass
child_stop_q.join_thread()
thread.join(timeout=2)

os.environ.update(env)
dask.config.refresh()
dask.config.set(config)

from dask.multiprocessing import default_initializer

default_initializer()

if silence_logs:
logger.setLevel(silence_logs)

asyncio.run(run())


def _get_env_variables(config_key: str) -> dict[str, str]:
Expand Down
23 changes: 20 additions & 3 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@
map_varying,
nodebug,
popen,
pristine_loop,
randominc,
save_sys_modules,
slowadd,
Expand Down Expand Up @@ -2206,8 +2205,26 @@ async def test_multi_client(s, a, b):
await asyncio.sleep(0.01)


@contextmanager
def _pristine_loop():
IOLoop.clear_instance()
IOLoop.clear_current()
loop = IOLoop()
loop.make_current()
assert IOLoop.current() is loop
try:
yield loop
finally:
try:
loop.close(all_fds=True)
except (KeyError, ValueError):
pass
IOLoop.clear_instance()
IOLoop.clear_current()


def long_running_client_connection(address):
with pristine_loop():
with _pristine_loop():
c = Client(address)
x = c.submit(lambda x: x + 1, 10)
x.result()
Expand Down Expand Up @@ -5601,7 +5618,7 @@ async def close():
async with client:
pass

with pristine_loop() as loop:
with _pristine_loop() as loop:
with pytest.warns(
DeprecationWarning,
match=r"Constructing LoopRunner\(loop=loop\) without a running loop is deprecated",
Expand Down
18 changes: 0 additions & 18 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,6 @@ async def run():
return


@contextmanager
def pristine_loop():
IOLoop.clear_instance()
IOLoop.clear_current()
loop = IOLoop()
loop.make_current()
assert IOLoop.current() is loop
try:
yield loop
finally:
try:
loop.close(all_fds=True)
except (KeyError, ValueError):
pass
IOLoop.clear_instance()
IOLoop.clear_current()


original_config = copy.deepcopy(dask.config.config)


Expand Down

0 comments on commit 88515db

Please sign in to comment.