Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

configure asyncio loop using loop_factory kwarg rather than using the set_event_loop_policy #7969

Merged
merged 3 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion distributed/cli/dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from distributed import Scheduler
from distributed._signals import wait_for_signals
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.preloading import validate_preload_argv
from distributed.proctitle import (
enable_proctitle_on_children,
Expand Down Expand Up @@ -246,7 +248,7 @@ async def wait_for_signals_and_close():
logger.info("Stopped scheduler at %r", scheduler.address)

try:
asyncio.run(run())
asyncio_run(run(), loop_factory=get_loop_factory())
finally:
logger.info("End scheduler")

Expand Down
4 changes: 3 additions & 1 deletion distributed/cli/dask_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import click
import yaml

from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.deploy.spec import run_spec


Expand Down Expand Up @@ -39,7 +41,7 @@ async def run():
except KeyboardInterrupt:
await asyncio.gather(*(w.close() for w in servers.values()))

asyncio.run(run())
asyncio_run(run(), loop_factory=get_loop_factory())


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from distributed import Nanny
from distributed._signals import wait_for_signals
from distributed.comm import get_address_host_port
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.deploy.utils import nprocesses_nthreads
from distributed.preloading import validate_preload_argv
from distributed.proctitle import (
Expand Down Expand Up @@ -443,7 +445,7 @@ async def wait_for_signals_and_close():
[task.result() for task in done]

try:
asyncio.run(run())
asyncio_run(run(), loop_factory=get_loop_factory())
except (TimeoutError, asyncio.TimeoutError):
# We already log the exception in nanny / worker. Don't do it again.
if not signal_fired:
Expand Down
6 changes: 5 additions & 1 deletion distributed/comm/tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
unparse_host_port,
)
from distributed.comm.registry import backends, get_backend
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.metrics import time
from distributed.protocol import Serialized, deserialize, serialize, to_serialize
from distributed.utils import get_ip, get_ipv6, get_mp_context, wait_for
Expand Down Expand Up @@ -438,7 +440,9 @@ async def run_with_timeout():
t = asyncio.create_task(func(*args, **kwargs))
return await wait_for(t, timeout=10)

return await asyncio.to_thread(asyncio.run, run_with_timeout())
return await asyncio.to_thread(
asyncio_run, run_with_timeout(), loop_factory=get_loop_factory()
)


@gen_test()
Expand Down
85 changes: 84 additions & 1 deletion distributed/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import random
import sys
import warnings
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar

import tornado

Expand Down Expand Up @@ -48,7 +50,7 @@ def randbytes(*args, **kwargs):
# takes longer than the interval
import datetime
import math
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable
from inspect import isawaitable

from tornado.ioloop import IOLoop
Expand Down Expand Up @@ -182,3 +184,84 @@ def _update_next(self, current_time: float) -> None:
# time.monotonic().
# https://github.com/tornadoweb/tornado/issues/2333
self._next_timeout += callback_time_sec


_T = TypeVar("_T")

if sys.version_info >= (3, 12):
asyncio_run = asyncio.run
elif sys.version_info >= (3, 11):

def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
# asyncio.run from Python 3.12
# https://docs.python.org/3/license.html#psf-license
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(main)

else:
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
# https://docs.python.org/3/license.html#psf-license
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop"
)

if not asyncio.iscoroutine(main):
raise ValueError(f"a coroutine was expected, got {main!r}")

if loop_factory is None:
loop = asyncio.new_event_loop()
else:
loop = loop_factory()
try:
if loop_factory is None:
asyncio.set_event_loop(loop)
Comment on lines +232 to +233
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this not necessary when using a loop_factory?

Copy link
Member Author

@graingert graingert Jul 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the behaviour from 3.11+ when loop_factory was introduced to asyncio.Runner

https://github.com/python/cpython/blob/b03755a2347325a89a48b08fc158419000513bcb/Lib/asyncio/runners.py#L136-L142

it isn't called when the loop_factory is specified because otherwise it would call asyncio.get_event_loop_policy().set_event_loop(loop) on the default policy with a loop from another policy which isn't supported

set_event_loop is currently only used by the posix _UnixDefaultEventLoopPolicy to set an eventloop for listening to child process events with a child process watcher, which is deprecated and not used on windows or uvloop where we use the loop_factory kwarg

if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if loop_factory is None:
asyncio.set_event_loop(None)
Comment on lines +243 to +244
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

loop.close()

def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
22 changes: 11 additions & 11 deletions distributed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging.config
import os
import sys
from collections.abc import Callable
from typing import Any

import yaml
Expand Down Expand Up @@ -177,7 +178,7 @@ def initialize_logging(config: dict[Any, Any]) -> None:
_initialize_logging_old_style(config)


def initialize_event_loop(config: dict[Any, Any]) -> None:
def get_loop_factory() -> Callable[[], asyncio.AbstractEventLoop] | None:
event_loop = dask.config.get("distributed.admin.event-loop")
if event_loop == "uvloop":
uvloop = import_required(
Expand All @@ -189,19 +190,18 @@ def initialize_event_loop(config: dict[Any, Any]) -> None:
" conda install uvloop\n"
" pip install uvloop",
)
uvloop.install()
elif event_loop in {"asyncio", "tornado"}:
return uvloop.new_event_loop
if event_loop in {"asyncio", "tornado"}:
if sys.platform == "win32":
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
# ProactorEventLoop is not compatible with tornado 6
# fallback to the pre-3.8 default of Selector
# https://github.com/tornadoweb/tornado/issues/2608
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
else:
raise ValueError(
"Expected distributed.admin.event-loop to be in ('asyncio', 'tornado', 'uvloop'), got %s"
% dask.config.get("distributed.admin.event-loop")
)
return asyncio.SelectorEventLoop
return None
raise ValueError(
"Expected distributed.admin.event-loop to be in ('asyncio', 'tornado', 'uvloop'), got %s"
% dask.config.get("distributed.admin.event-loop")
)


initialize_logging(dask.config.config)
initialize_event_loop(dask.config.config)
5 changes: 3 additions & 2 deletions distributed/deploy/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from dask.system import CPU_COUNT

from distributed import Client, LocalCluster, Nanny, Worker, get_client
from distributed.compatibility import LINUX
from distributed.compatibility import LINUX, asyncio_run
from distributed.config import get_loop_factory
from distributed.core import Status
from distributed.metrics import time
from distributed.system import MEMORY_LIMIT
Expand Down Expand Up @@ -670,7 +671,7 @@ async def amain():
box = cluster._cached_widget
assert isinstance(box, ipywidgets.Widget)

asyncio.run(amain())
asyncio_run(amain(), loop_factory=get_loop_factory())


def test_no_ipywidgets(loop, monkeypatch):
Expand Down
4 changes: 3 additions & 1 deletion distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from distributed import preloading
from distributed.comm import get_address_host
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.core import (
AsyncTaskGroupClosedError,
CommClosedError,
Expand Down Expand Up @@ -996,7 +998,7 @@ def close_stop_q() -> None:
if silence_logs:
logger.setLevel(silence_logs)

asyncio.run(run())
asyncio_run(run(), loop_factory=get_loop_factory())


def _get_env_variables(config_key: str) -> dict[str, str]:
Expand Down
5 changes: 3 additions & 2 deletions distributed/tests/test_asyncprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import pytest
from tornado.ioloop import IOLoop

from distributed.compatibility import LINUX, MACOS, WINDOWS
from distributed.compatibility import LINUX, MACOS, WINDOWS, asyncio_run
from distributed.config import get_loop_factory
from distributed.metrics import time
from distributed.process import AsyncProcess
from distributed.utils import get_mp_context, wait_for
Expand Down Expand Up @@ -389,7 +390,7 @@ async def run_with_timeout():
t = asyncio.create_task(parent_process_coroutine())
return await wait_for(t, timeout=10)

asyncio.run(run_with_timeout())
asyncio_run(run_with_timeout(), loop_factory=get_loop_factory())
raise RuntimeError("this should be unreachable due to os._exit")


Expand Down
16 changes: 7 additions & 9 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,15 +1365,13 @@ async def test_update_graph_culls(s, a, b):
assert "z" not in s.tasks


def test_io_loop(loop):
async def main():
with pytest.warns(
DeprecationWarning, match=r"the loop kwarg to Scheduler is deprecated"
):
s = Scheduler(loop=loop, dashboard_address=":0", validate=True)
assert s.io_loop is IOLoop.current()

asyncio.run(main())
@gen_test()
async def test_io_loop(loop):
with pytest.warns(
DeprecationWarning, match=r"the loop kwarg to Scheduler is deprecated"
):
s = Scheduler(loop=loop, dashboard_address=":0", validate=True)
assert s.io_loop is IOLoop.current()


@gen_cluster(client=True)
Expand Down
13 changes: 8 additions & 5 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

import dask

from distributed.compatibility import MACOS, WINDOWS
from distributed.compatibility import MACOS, WINDOWS, asyncio_run
from distributed.config import get_loop_factory
from distributed.metrics import time
from distributed.utils import (
All,
Expand Down Expand Up @@ -134,7 +135,7 @@ def test_sync_closed_loop():
async def get_loop():
return IOLoop.current()

loop = asyncio.run(get_loop())
loop = asyncio_run(get_loop(), loop_factory=get_loop_factory())
loop.close()

with pytest.raises(RuntimeError) as exc_info:
Expand Down Expand Up @@ -399,7 +400,9 @@ def test_loop_runner(loop_in_thread):
async def make_looprunner_in_async_context():
return IOLoop.current(), LoopRunner()

loop, runner = asyncio.run(make_looprunner_in_async_context())
loop, runner = asyncio_run(
make_looprunner_in_async_context(), loop_factory=get_loop_factory()
)
with pytest.raises(
RuntimeError,
match=r"Accessing the loop property while the loop is not running is not supported",
Expand All @@ -423,7 +426,7 @@ async def make_io_loop_in_async_context():
return IOLoop.current()

# Explicit loop
loop = asyncio.run(make_io_loop_in_async_context())
loop = asyncio_run(make_io_loop_in_async_context(), loop_factory=get_loop_factory())
with pytest.raises(
RuntimeError,
match=r"Constructing LoopRunner\(loop=loop\) without a running loop is not supported",
Expand All @@ -449,7 +452,7 @@ async def make_io_loop_in_async_context():
LoopRunner(asynchronous=True)

# Explicit loop
loop = asyncio.run(make_io_loop_in_async_context())
loop = asyncio_run(make_io_loop_in_async_context(), loop_factory=get_loop_factory())
with pytest.raises(
RuntimeError,
match=r"Constructing LoopRunner\(loop=loop\) without a running loop is not supported",
Expand Down
9 changes: 5 additions & 4 deletions distributed/tests/test_utils_comm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import asyncio
from unittest import mock

import pytest

from dask.optimization import SubgraphCallable

from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.core import ConnectionPool
from distributed.utils_comm import (
WrappedKey,
Expand Down Expand Up @@ -81,7 +82,7 @@ async def coro():
async def f():
return await retry(coro, count=0, delay_min=-1, delay_max=-1)

assert asyncio.run(f()) is retval
assert asyncio_run(f(), loop_factory=get_loop_factory()) is retval
assert n_calls == 1


Expand All @@ -99,7 +100,7 @@ async def f():
return await retry(coro, count=0, delay_min=-1, delay_max=-1)

with pytest.raises(RuntimeError, match="RT_ERROR 1"):
asyncio.run(f())
asyncio_run(f(), loop_factory=get_loop_factory())

assert n_calls == 1

Expand Down Expand Up @@ -134,7 +135,7 @@ async def f():

with mock.patch("asyncio.sleep", my_sleep):
with pytest.raises(MyEx, match="RT_ERROR 6"):
asyncio.run(f())
asyncio_run(f(), loop_factory=get_loop_factory())

assert n_calls == 6
assert sleep_calls == [0.0, 1.0, 3.0, 6.0, 6.0]
Expand Down
Loading