From 8b7da1d69f0178df346edcdb123d1576407e7efc Mon Sep 17 00:00:00 2001 From: Casper van der Wel <casper.vanderwel@nelen-schuurmans.nl> Date: Mon, 6 Mar 2023 21:30:30 +0100 Subject: [PATCH 1/2] Port AsyncMiddleware --- dramatiq/middleware/asyncio.py | 152 +++++++++++++++++++++++++++++++ tests/middleware/test_asyncio.py | 111 ++++++++++++++++++++++ 2 files changed, 263 insertions(+) create mode 100644 dramatiq/middleware/asyncio.py create mode 100644 tests/middleware/test_asyncio.py diff --git a/dramatiq/middleware/asyncio.py b/dramatiq/middleware/asyncio.py new file mode 100644 index 00000000..7a170a30 --- /dev/null +++ b/dramatiq/middleware/asyncio.py @@ -0,0 +1,152 @@ +import asyncio +import threading +import time +from concurrent.futures import TimeoutError +from typing import Awaitable, Optional + +import dramatiq +from dramatiq.middleware import Middleware + +from ..logging import get_logger +from .threading import Interrupt + +__all__ = ["AsyncActor", "AsyncMiddleware", "async_actor"] + + +class EventLoopThread(threading.Thread): + """A thread that runs an asyncio event loop. + + The method 'run_coroutine' should be used to run coroutines from a + synchronous context. 
+ """ + + # seconds to wait for the event loop to start + EVENT_LOOP_START_TIMEOUT = 0.1 + # interval (seconds) to reactivate the worker thread and check + # for interrupts + INTERRUPT_CHECK_INTERVAL = 1.0 + + loop: Optional[asyncio.AbstractEventLoop] = None + + def __init__(self, logger): + self.logger = logger + super().__init__(target=self._start_event_loop) + + def _start_event_loop(self): + """This method should run in the thread""" + self.logger.info("Starting the event loop...") + + self.loop = asyncio.new_event_loop() + try: + self.loop.run_forever() + finally: + self.loop.close() + + def _stop_event_loop(self): + """This method should run outside of the thread""" + if self.loop is not None and self.loop.is_running(): + self.logger.info("Stopping the event loop...") + self.loop.call_soon_threadsafe(self.loop.stop) + + def run_coroutine(self, coro: Awaitable) -> None: + """To be called from outside the thread + + Blocks until the coroutine is finished. + """ + if self.loop is None or not self.loop.is_running(): + raise RuntimeError("The event loop is not running") + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + while True: + try: + # Use a timeout to be able to catch asynchronously raised dramatiq + # exceptions (Interrupt). + future.result(timeout=self.INTERRUPT_CHECK_INTERVAL) + except Interrupt: + # Asynchronously raised from another thread: cancel the future and + # reiterate to wait for possible cleanup actions. 
+ self.loop.call_soon_threadsafe(future.cancel) + except TimeoutError: + continue + + break + + def start(self, *args, **kwargs): + super().start(*args, **kwargs) + time.sleep(self.EVENT_LOOP_START_TIMEOUT) + if self.loop is None or not self.loop.is_running(): + raise RuntimeError("The event loop failed to start") + self.logger.info("Event loop is running.") + + def join(self, *args, **kwargs): + self._stop_event_loop() + return super().join(*args, **kwargs) + + +class AsyncMiddleware(Middleware): + """This middleware manages the event loop thread. + + This thread is used to schedule coroutines on from the worker threads. + """ + + event_loop_thread: Optional[EventLoopThread] = None + + def __init__(self): + self.logger = get_logger(__name__, type(self)) + + def run_coroutine(self, coro: Awaitable) -> None: + self.event_loop_thread.run_coroutine(coro) + + def before_worker_boot(self, broker, worker): + self.event_loop_thread = EventLoopThread(self.logger) + self.event_loop_thread.start() + + # Monkeypatch the broker to make the event loop thread reachable + # from an actor or from other middleware. + broker.run_coroutine = self.event_loop_thread.run_coroutine + + def after_worker_shutdown(self, broker, worker): + self.event_loop_thread.join() + self.event_loop_thread = None + + delattr(broker, "run_coroutine") + + +class AsyncActor(dramatiq.Actor): + """To configure coroutines as a dramatiq actor. + + Requires AsyncMiddleware to be active. + + Example usage: + + >>> @dramatiq.actor(..., actor_class=AsyncActor) + ... async def my_task(x): + ... print(x) + + Notes: + + The coroutine is scheduled on an event loop that is shared between + worker threads. See AsyncMiddleware and EventLoopThread. + + This is compatible with ShutdownNotifications ("notify_shutdown") and + TimeLimit ("time_limit"). Both result in an asyncio.CancelledError raised inside + the async function. There is currently no way to tell the two apart from + within the coroutine. 
+ """ + + def __init__(self, fn, *args, **kwargs): + super().__init__( + lambda *args, **kwargs: self.broker.run_coroutine(fn(*args, **kwargs)), + *args, + **kwargs + ) + + +def async_actor(awaitable=None, **kwargs): + if awaitable is not None: + return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs) + else: + + def wrapper(awaitable): + return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs) + + return wrapper diff --git a/tests/middleware/test_asyncio.py b/tests/middleware/test_asyncio.py new file mode 100644 index 00000000..b7fab662 --- /dev/null +++ b/tests/middleware/test_asyncio.py @@ -0,0 +1,111 @@ +import asyncio +import threading +from unittest import mock + +import pytest + +from dramatiq.middleware.asyncio import AsyncActor, AsyncMiddleware, EventLoopThread, async_actor + + +@pytest.fixture +def started_thread(): + thread = EventLoopThread(logger=mock.Mock()) + thread.start() + yield thread + thread.join() + + +@pytest.fixture +def logger(): + return mock.Mock() + + +def test_event_loop_thread_start(): + try: + thread = EventLoopThread(logger=mock.Mock()) + thread.start() + assert isinstance(thread.loop, asyncio.BaseEventLoop) + assert thread.loop.is_running() + finally: + thread.join() + + +def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread): + result = {} + + async def get_thread_id(): + result["thread_id"] = threading.get_ident() + + started_thread.run_coroutine(get_thread_id()) + + # the coroutine executed in the event loop thread + assert result["thread_id"] == started_thread.ident + + +def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread): + async def raise_error(): + raise TypeError("bla") + + coro = raise_error() + + with pytest.raises(TypeError, match="bla"): + started_thread.run_coroutine(coro) + + +@mock.patch.object(EventLoopThread, "start") +@mock.patch.object(EventLoopThread, "run_coroutine") +def test_async_middleware_before_worker_boot( + 
EventLoopThread_run_coroutine, EventLoopThread_start +): + broker = mock.Mock() + worker = mock.Mock() + middleware = AsyncMiddleware() + + middleware.before_worker_boot(broker, worker) + + assert isinstance(middleware.event_loop_thread, EventLoopThread) + + EventLoopThread_start.assert_called_once() + + middleware.run_coroutine("foo") + EventLoopThread_run_coroutine.assert_called_once_with("foo") + + # broker was patched with run_coroutine + broker.run_coroutine("bar") + EventLoopThread_run_coroutine.assert_called_with("bar") + + +def test_async_middleware_after_worker_shutdown(): + broker = mock.Mock() + broker.run_coroutine = lambda x: x + worker = mock.Mock() + event_loop_thread = mock.Mock() + + middleware = AsyncMiddleware() + middleware.event_loop_thread = event_loop_thread + middleware.after_worker_shutdown(broker, worker) + + event_loop_thread.join.assert_called_once() + assert middleware.event_loop_thread is None + assert not hasattr(broker, "run_coroutine") + + +def test_async_actor(started_thread): + broker = mock.Mock() + broker.actor_options = {"max_retries"} + + @async_actor(broker=broker) + async def foo(*args, **kwargs): + pass + + assert isinstance(foo, AsyncActor) + + foo(2, a="b") + + broker.run_coroutine.assert_called_once() + + # no recursion errors here: + repr(foo) + + # this is just to stop "never awaited" warnings + started_thread.run_coroutine(broker.run_coroutine.call_args[0][0]) From 4483bab45ca164ff50e233187d762cfb5b8e6c01 Mon Sep 17 00:00:00 2001 From: Casper van der Wel <casper.vanderwel@nelen-schuurmans.nl> Date: Fri, 31 Mar 2023 13:58:51 +0200 Subject: [PATCH 2/2] Refactor --- dramatiq/actor.py | 13 ++-- dramatiq/middleware/asyncio.py | 119 +++++++++++++++---------------- tests/middleware/test_asyncio.py | 93 +++++++++++++----------- tests/test_actors.py | 16 +++++ 4 files changed, 132 insertions(+), 109 deletions(-) diff --git a/dramatiq/actor.py b/dramatiq/actor.py index 54adcfce..b2fe214b 100644 --- a/dramatiq/actor.py +++ 
b/dramatiq/actor.py @@ -18,7 +18,8 @@ import re import time -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, Union, overload +from inspect import iscoroutinefunction +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar, Union, overload from .broker import Broker, get_broker from .logging import get_logger @@ -51,10 +52,9 @@ class Actor(Generic[P, R]): options(dict): Arbitrary options that are passed to the broker and middleware. """ - def __init__( self, - fn: Callable[P, R], + fn: Callable[P, Union[R, Awaitable[R]]], *, broker: Broker, actor_name: str, @@ -63,7 +63,12 @@ def __init__( options: Dict[str, Any], ) -> None: self.logger = get_logger(fn.__module__, actor_name) - self.fn = fn + if iscoroutinefunction(fn): + from dramatiq.middleware.asyncio import async_to_sync + + self.fn = async_to_sync(fn) + else: + self.fn = fn # type: ignore self.broker = broker self.actor_name = actor_name self.queue_name = queue_name diff --git a/dramatiq/middleware/asyncio.py b/dramatiq/middleware/asyncio.py index 7a170a30..455c696a 100644 --- a/dramatiq/middleware/asyncio.py +++ b/dramatiq/middleware/asyncio.py @@ -1,16 +1,63 @@ +from __future__ import annotations + import asyncio +import functools import threading import time from concurrent.futures import TimeoutError -from typing import Awaitable, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Optional, TypeVar -import dramatiq from dramatiq.middleware import Middleware from ..logging import get_logger from .threading import Interrupt -__all__ = ["AsyncActor", "AsyncMiddleware", "async_actor"] +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") +else: + P = TypeVar("P") +R = TypeVar("R") + +__all__ = ["AsyncMiddleware", "async_to_sync"] + +# the global event loop thread +global_event_loop_thread = None + + +def get_event_loop_thread() -> "EventLoopThread": + """Get the global event loop thread. 
+ + If no global event loop thread is set, a RuntimeError will be raised. + + Returns: + EventLoopThread: The global EventLoopThread. + """ + global global_event_loop_thread + if global_event_loop_thread is None: + raise RuntimeError( + "The usage of asyncio in dramatiq requires the AsyncMiddleware " + "to be configured." + ) + return global_event_loop_thread + + +def set_event_loop_thread(event_loop_thread: Optional["EventLoopThread"]) -> None: + global global_event_loop_thread + global_event_loop_thread = event_loop_thread + + +def async_to_sync(async_fn: Callable[P, Awaitable[R]]) -> Callable[P, R]: + """Wrap an 'async def' function to make it synchronous.""" + # assert presence of event loop thread: + get_event_loop_thread() + + @functools.wraps(async_fn) + def wrapper(*args, **kwargs) -> R: + return get_event_loop_thread().run_coroutine(async_fn(*args, **kwargs)) + + return wrapper class EventLoopThread(threading.Thread): @@ -48,7 +95,7 @@ def _stop_event_loop(self): self.logger.info("Stopping the event loop...") self.loop.call_soon_threadsafe(self.loop.stop) - def run_coroutine(self, coro: Awaitable) -> None: + def run_coroutine(self, coro: Awaitable[R]) -> R: """To be called from outside the thread Blocks until the coroutine is finished. @@ -60,7 +107,7 @@ def run_coroutine(self, coro: Awaitable) -> None: try: # Use a timeout to be able to catch asynchronously raised dramatiq # exceptions (Interrupt). - future.result(timeout=self.INTERRUPT_CHECK_INTERVAL) + return future.result(timeout=self.INTERRUPT_CHECK_INTERVAL) except Interrupt: # Asynchronously raised from another thread: cancel the future and # reiterate to wait for possible cleanup actions.
@@ -68,8 +115,6 @@ def run_coroutine(self, coro: Awaitable) -> None: except TimeoutError: continue - break - def start(self, *args, **kwargs): super().start(*args, **kwargs) time.sleep(self.EVENT_LOOP_START_TIMEOUT) @@ -88,65 +133,15 @@ class AsyncMiddleware(Middleware): This thread is used to schedule coroutines on from the worker threads. """ - event_loop_thread: Optional[EventLoopThread] = None - def __init__(self): self.logger = get_logger(__name__, type(self)) - def run_coroutine(self, coro: Awaitable) -> None: - self.event_loop_thread.run_coroutine(coro) - def before_worker_boot(self, broker, worker): - self.event_loop_thread = EventLoopThread(self.logger) - self.event_loop_thread.start() + event_loop_thread = EventLoopThread(self.logger) + event_loop_thread.start() - # Monkeypatch the broker to make the event loop thread reachable - # from an actor or from other middleware. - broker.run_coroutine = self.event_loop_thread.run_coroutine + set_event_loop_thread(event_loop_thread) def after_worker_shutdown(self, broker, worker): - self.event_loop_thread.join() - self.event_loop_thread = None - - delattr(broker, "run_coroutine") - - -class AsyncActor(dramatiq.Actor): - """To configure coroutines as a dramatiq actor. - - Requires AsyncMiddleware to be active. - - Example usage: - - >>> @dramatiq.actor(..., actor_class=AsyncActor) - ... async def my_task(x): - ... print(x) - - Notes: - - The coroutine is scheduled on an event loop that is shared between - worker threads. See AsyncMiddleware and EventLoopThread. - - This is compatible with ShutdownNotifications ("notify_shutdown") and - TimeLimit ("time_limit"). Both result in an asyncio.CancelledError raised inside - the async function. There is currently no way to tell the two apart from - within the coroutine. 
- """ - - def __init__(self, fn, *args, **kwargs): - super().__init__( - lambda *args, **kwargs: self.broker.run_coroutine(fn(*args, **kwargs)), - *args, - **kwargs - ) - - -def async_actor(awaitable=None, **kwargs): - if awaitable is not None: - return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs) - else: - - def wrapper(awaitable): - return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs) - - return wrapper + get_event_loop_thread().join() + set_event_loop_thread(None) diff --git a/tests/middleware/test_asyncio.py b/tests/middleware/test_asyncio.py index b7fab662..a9ed9578 100644 --- a/tests/middleware/test_asyncio.py +++ b/tests/middleware/test_asyncio.py @@ -4,15 +4,23 @@ import pytest -from dramatiq.middleware.asyncio import AsyncActor, AsyncMiddleware, EventLoopThread, async_actor +from dramatiq.middleware.asyncio import ( + AsyncMiddleware, + EventLoopThread, + async_to_sync, + get_event_loop_thread, + set_event_loop_thread, +) @pytest.fixture def started_thread(): thread = EventLoopThread(logger=mock.Mock()) thread.start() + set_event_loop_thread(thread) yield thread thread.join() + set_event_loop_thread(None) @pytest.fixture @@ -34,12 +42,12 @@ def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread): result = {} async def get_thread_id(): - result["thread_id"] = threading.get_ident() + return threading.get_ident() - started_thread.run_coroutine(get_thread_id()) + result = started_thread.run_coroutine(get_thread_id()) # the coroutine executed in the event loop thread - assert result["thread_id"] == started_thread.ident + assert result == started_thread.ident def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread): @@ -52,60 +60,59 @@ async def raise_error(): started_thread.run_coroutine(coro) -@mock.patch.object(EventLoopThread, "start") -@mock.patch.object(EventLoopThread, "run_coroutine") -def test_async_middleware_before_worker_boot( - EventLoopThread_run_coroutine, 
EventLoopThread_start -): - broker = mock.Mock() - worker = mock.Mock() +@mock.patch("dramatiq.middleware.asyncio.EventLoopThread") +def test_async_middleware_before_worker_boot(EventLoopThreadMock): middleware = AsyncMiddleware() - middleware.before_worker_boot(broker, worker) - - assert isinstance(middleware.event_loop_thread, EventLoopThread) - - EventLoopThread_start.assert_called_once() + try: + middleware.before_worker_boot(None, None) - middleware.run_coroutine("foo") - EventLoopThread_run_coroutine.assert_called_once_with("foo") + assert get_event_loop_thread() is EventLoopThreadMock.return_value - # broker was patched with run_coroutine - broker.run_coroutine("bar") - EventLoopThread_run_coroutine.assert_called_with("bar") + EventLoopThreadMock.assert_called_once_with(middleware.logger) + EventLoopThreadMock().start.assert_called_once_with() + finally: + set_event_loop_thread(None) def test_async_middleware_after_worker_shutdown(): - broker = mock.Mock() - broker.run_coroutine = lambda x: x - worker = mock.Mock() + middleware = AsyncMiddleware() event_loop_thread = mock.Mock() - middleware = AsyncMiddleware() - middleware.event_loop_thread = event_loop_thread - middleware.after_worker_shutdown(broker, worker) + set_event_loop_thread(event_loop_thread) + + try: + middleware.after_worker_shutdown(None, None) + + with pytest.raises(RuntimeError): + get_event_loop_thread() + + event_loop_thread.join.assert_called_once_with() + finally: + set_event_loop_thread(None) + - event_loop_thread.join.assert_called_once() - assert middleware.event_loop_thread is None - assert not hasattr(broker, "run_coroutine") +async def async_fn(value: int = 2) -> int: + return value + 1 -def test_async_actor(started_thread): - broker = mock.Mock() - broker.actor_options = {"max_retries"} +@mock.patch("dramatiq.middleware.asyncio.get_event_loop_thread") +def test_async_to_sync(get_event_loop_thread_mocked): + thread = get_event_loop_thread_mocked() - @async_actor(broker=broker) - 
async def foo(*args, **kwargs): - pass + fn = async_to_sync(async_fn) + actual = fn(2) + thread.run_coroutine.assert_called_once() + assert actual is thread.run_coroutine() - assert isinstance(foo, AsyncActor) - foo(2, a="b") +@pytest.mark.usefixtures("started_thread") +def test_async_to_sync_with_actual_thread(started_thread): + fn = async_to_sync(async_fn) - broker.run_coroutine.assert_called_once() + assert fn(2) == 3 - # no recursion errors here: - repr(foo) - # this is just to stop "never awaited" warnings - started_thread.run_coroutine(broker.run_coroutine.call_args[0][0]) +def test_async_to_sync_no_thread(): + with pytest.raises(RuntimeError): + async_to_sync(async_fn) diff --git a/tests/test_actors.py b/tests/test_actors.py index 6d9caf42..7f64c4c8 100644 --- a/tests/test_actors.py +++ b/tests/test_actors.py @@ -482,3 +482,19 @@ def accessor(x): # When I try to access the current message from a non-worker thread # Then I should get back None assert CurrentMessage.get_current_message() is None + + +@patch("dramatiq.middleware.asyncio.async_to_sync") +def test_actors_can_wrap_asyncio(async_to_sync_mock, stub_broker): + # Define an asyncio function and wrap it in an actor + async def add(x, y): + return x + y + + actor = dramatiq.actor(add) + + # I expect that function to become an instance of Actor + assert isinstance(actor, dramatiq.Actor) + + # The wrapped function should be wrapped with 'async_to_sync' + async_to_sync_mock.assert_called_once_with(add) + assert actor.fn == async_to_sync_mock.return_value