From 8b7da1d69f0178df346edcdb123d1576407e7efc Mon Sep 17 00:00:00 2001
From: Casper van der Wel <casper.vanderwel@nelen-schuurmans.nl>
Date: Mon, 6 Mar 2023 21:30:30 +0100
Subject: [PATCH 1/2] Port AsyncMiddleware

---
 dramatiq/middleware/asyncio.py   | 152 +++++++++++++++++++++++++++++++
 tests/middleware/test_asyncio.py | 111 ++++++++++++++++++++++
 2 files changed, 263 insertions(+)
 create mode 100644 dramatiq/middleware/asyncio.py
 create mode 100644 tests/middleware/test_asyncio.py

diff --git a/dramatiq/middleware/asyncio.py b/dramatiq/middleware/asyncio.py
new file mode 100644
index 00000000..7a170a30
--- /dev/null
+++ b/dramatiq/middleware/asyncio.py
@@ -0,0 +1,152 @@
+import asyncio
+import threading
+import time
+from concurrent.futures import TimeoutError
+from typing import Awaitable, Optional
+
+import dramatiq
+from dramatiq.middleware import Middleware
+
+from ..logging import get_logger
+from .threading import Interrupt
+
+__all__ = ["AsyncActor", "AsyncMiddleware", "async_actor"]
+
+
+class EventLoopThread(threading.Thread):
+    """A thread that runs an asyncio event loop.
+
+    The method 'run_coroutine' should be used to run coroutines from a
+    synchronous context.
+    """
+
+    # seconds to wait for the event loop to start
+    EVENT_LOOP_START_TIMEOUT = 0.1
+    # interval (seconds) to reactivate the worker thread and check
+    # for interrupts
+    INTERRUPT_CHECK_INTERVAL = 1.0
+
+    loop: Optional[asyncio.AbstractEventLoop] = None
+
+    def __init__(self, logger):
+        self.logger = logger
+        super().__init__(target=self._start_event_loop)
+
+    def _start_event_loop(self):
+        """This method should run in the thread"""
+        self.logger.info("Starting the event loop...")
+
+        self.loop = asyncio.new_event_loop()
+        try:
+            self.loop.run_forever()
+        finally:
+            self.loop.close()
+
+    def _stop_event_loop(self):
+        """This method should run outside of the thread"""
+        if self.loop is not None and self.loop.is_running():
+            self.logger.info("Stopping the event loop...")
+            self.loop.call_soon_threadsafe(self.loop.stop)
+
+    def run_coroutine(self, coro: Awaitable) -> None:
+        """To be called from outside the thread
+
+        Blocks until the coroutine is finished.
+        """
+        if self.loop is None or not self.loop.is_running():
+            raise RuntimeError("The event loop is not running")
+        future = asyncio.run_coroutine_threadsafe(coro, self.loop)
+        while True:
+            try:
+                # Use a timeout to be able to catch asynchronously raised dramatiq
+                # exceptions (Interrupt).
+                future.result(timeout=self.INTERRUPT_CHECK_INTERVAL)
+            except Interrupt:
+                # Asynchronously raised from another thread: cancel the future and
+                # reiterate to wait for possible cleanup actions.
+                self.loop.call_soon_threadsafe(future.cancel)
+            except TimeoutError:
+                continue
+
+            break
+
+    def start(self, *args, **kwargs):
+        super().start(*args, **kwargs)
+        time.sleep(self.EVENT_LOOP_START_TIMEOUT)
+        if self.loop is None or not self.loop.is_running():
+            raise RuntimeError("The event loop failed to start")
+        self.logger.info("Event loop is running.")
+
+    def join(self, *args, **kwargs):
+        self._stop_event_loop()
+        return super().join(*args, **kwargs)
+
+
+class AsyncMiddleware(Middleware):
+    """This middleware manages the event loop thread.
+
+    This thread is used to schedule coroutines on from the worker threads.
+    """
+
+    event_loop_thread: Optional[EventLoopThread] = None
+
+    def __init__(self):
+        self.logger = get_logger(__name__, type(self))
+
+    def run_coroutine(self, coro: Awaitable) -> None:
+        self.event_loop_thread.run_coroutine(coro)
+
+    def before_worker_boot(self, broker, worker):
+        self.event_loop_thread = EventLoopThread(self.logger)
+        self.event_loop_thread.start()
+
+        # Monkeypatch the broker to make the event loop thread reachable
+        # from an actor or from other middleware.
+        broker.run_coroutine = self.event_loop_thread.run_coroutine
+
+    def after_worker_shutdown(self, broker, worker):
+        self.event_loop_thread.join()
+        self.event_loop_thread = None
+
+        delattr(broker, "run_coroutine")
+
+
+class AsyncActor(dramatiq.Actor):
+    """To configure coroutines as a dramatiq actor.
+
+    Requires AsyncMiddleware to be active.
+
+    Example usage:
+
+    >>> @dramatiq.actor(..., actor_class=AsyncActor)
+    ... async def my_task(x):
+    ...     print(x)
+
+    Notes:
+
+    The coroutine is scheduled on an event loop that is shared between
+    worker threads. See AsyncMiddleware and EventLoopThread.
+
+    This is compatible with ShutdownNotifications ("notify_shutdown") and
+    TimeLimit ("time_limit"). Both result in an asyncio.CancelledError raised inside
+    the async function. There is currently no way to tell the two apart from
+    within the coroutine.
+    """
+
+    def __init__(self, fn, *args, **kwargs):
+        super().__init__(
+            lambda *args, **kwargs: self.broker.run_coroutine(fn(*args, **kwargs)),
+            *args,
+            **kwargs
+        )
+
+
+def async_actor(awaitable=None, **kwargs):
+    if awaitable is not None:
+        return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs)
+    else:
+
+        def wrapper(awaitable):
+            return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs)
+
+        return wrapper
diff --git a/tests/middleware/test_asyncio.py b/tests/middleware/test_asyncio.py
new file mode 100644
index 00000000..b7fab662
--- /dev/null
+++ b/tests/middleware/test_asyncio.py
@@ -0,0 +1,111 @@
+import asyncio
+import threading
+from unittest import mock
+
+import pytest
+
+from dramatiq.middleware.asyncio import AsyncActor, AsyncMiddleware, EventLoopThread, async_actor
+
+
+@pytest.fixture
+def started_thread():
+    thread = EventLoopThread(logger=mock.Mock())
+    thread.start()
+    yield thread
+    thread.join()
+
+
+@pytest.fixture
+def logger():
+    return mock.Mock()
+
+
+def test_event_loop_thread_start():
+    try:
+        thread = EventLoopThread(logger=mock.Mock())
+        thread.start()
+        assert isinstance(thread.loop, asyncio.BaseEventLoop)
+        assert thread.loop.is_running()
+    finally:
+        thread.join()
+
+
+def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread):
+    result = {}
+
+    async def get_thread_id():
+        result["thread_id"] = threading.get_ident()
+
+    started_thread.run_coroutine(get_thread_id())
+
+    # the coroutine executed in the event loop thread
+    assert result["thread_id"] == started_thread.ident
+
+
+def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread):
+    async def raise_error():
+        raise TypeError("bla")
+
+    coro = raise_error()
+
+    with pytest.raises(TypeError, match="bla"):
+        started_thread.run_coroutine(coro)
+
+
+@mock.patch.object(EventLoopThread, "start")
+@mock.patch.object(EventLoopThread, "run_coroutine")
+def test_async_middleware_before_worker_boot(
+    EventLoopThread_run_coroutine, EventLoopThread_start
+):
+    broker = mock.Mock()
+    worker = mock.Mock()
+    middleware = AsyncMiddleware()
+
+    middleware.before_worker_boot(broker, worker)
+
+    assert isinstance(middleware.event_loop_thread, EventLoopThread)
+
+    EventLoopThread_start.assert_called_once()
+
+    middleware.run_coroutine("foo")
+    EventLoopThread_run_coroutine.assert_called_once_with("foo")
+
+    # broker was patched with run_coroutine
+    broker.run_coroutine("bar")
+    EventLoopThread_run_coroutine.assert_called_with("bar")
+
+
+def test_async_middleware_after_worker_shutdown():
+    broker = mock.Mock()
+    broker.run_coroutine = lambda x: x
+    worker = mock.Mock()
+    event_loop_thread = mock.Mock()
+
+    middleware = AsyncMiddleware()
+    middleware.event_loop_thread = event_loop_thread
+    middleware.after_worker_shutdown(broker, worker)
+
+    event_loop_thread.join.assert_called_once()
+    assert middleware.event_loop_thread is None
+    assert not hasattr(broker, "run_coroutine")
+
+
+def test_async_actor(started_thread):
+    broker = mock.Mock()
+    broker.actor_options = {"max_retries"}
+
+    @async_actor(broker=broker)
+    async def foo(*args, **kwargs):
+        pass
+
+    assert isinstance(foo, AsyncActor)
+
+    foo(2, a="b")
+
+    broker.run_coroutine.assert_called_once()
+
+    # no recursion errors here:
+    repr(foo)
+
+    # this is just to stop "never awaited" warnings
+    started_thread.run_coroutine(broker.run_coroutine.call_args[0][0])

From 4483bab45ca164ff50e233187d762cfb5b8e6c01 Mon Sep 17 00:00:00 2001
From: Casper van der Wel <casper.vanderwel@nelen-schuurmans.nl>
Date: Fri, 31 Mar 2023 13:58:51 +0200
Subject: [PATCH 2/2] Refactor

---
 dramatiq/actor.py                |  13 ++--
 dramatiq/middleware/asyncio.py   | 119 +++++++++++++++----------------
 tests/middleware/test_asyncio.py |  93 +++++++++++++-----------
 tests/test_actors.py             |  16 +++++
 4 files changed, 132 insertions(+), 109 deletions(-)

diff --git a/dramatiq/actor.py b/dramatiq/actor.py
index 54adcfce..b2fe214b 100644
--- a/dramatiq/actor.py
+++ b/dramatiq/actor.py
@@ -18,7 +18,8 @@
 
 import re
 import time
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, Union, overload
+from inspect import iscoroutinefunction
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar, Union, overload
 
 from .broker import Broker, get_broker
 from .logging import get_logger
@@ -51,10 +52,9 @@ class Actor(Generic[P, R]):
       options(dict): Arbitrary options that are passed to the broker
         and middleware.
     """
-
     def __init__(
         self,
-        fn: Callable[P, R],
+        fn: Callable[P, Union[R, Awaitable[R]]],
         *,
         broker: Broker,
         actor_name: str,
@@ -63,7 +63,12 @@ def __init__(
         options: Dict[str, Any],
     ) -> None:
         self.logger = get_logger(fn.__module__, actor_name)
-        self.fn = fn
+        if iscoroutinefunction(fn):
+            from dramatiq.middleware.asyncio import async_to_sync
+
+            self.fn = async_to_sync(fn)
+        else:
+            self.fn = fn  # type: ignore
         self.broker = broker
         self.actor_name = actor_name
         self.queue_name = queue_name
diff --git a/dramatiq/middleware/asyncio.py b/dramatiq/middleware/asyncio.py
index 7a170a30..455c696a 100644
--- a/dramatiq/middleware/asyncio.py
+++ b/dramatiq/middleware/asyncio.py
@@ -1,16 +1,63 @@
+from __future__ import annotations
+
 import asyncio
+import functools
 import threading
 import time
 from concurrent.futures import TimeoutError
-from typing import Awaitable, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Optional, TypeVar
 
-import dramatiq
 from dramatiq.middleware import Middleware
 
 from ..logging import get_logger
 from .threading import Interrupt
 
-__all__ = ["AsyncActor", "AsyncMiddleware", "async_actor"]
+if TYPE_CHECKING:
+    from typing_extensions import ParamSpec
+
+    P = ParamSpec("P")
+else:
+    P = TypeVar("P")
+R = TypeVar("R")
+
+__all__ = ["AsyncMiddleware", "async_to_sync"]
+
+# the global event loop thread
+global_event_loop_thread = None
+
+
+def get_event_loop_thread() -> "EventLoopThread":
+    """Get the global event loop thread.
+
+    If no global broker is set, RuntimeError error will be raised.
+
+    Returns:
+      Broker: The global EventLoopThread.
+    """
+    global global_event_loop_thread
+    if global_event_loop_thread is None:
+        raise RuntimeError(
+            "The usage of asyncio in dramatiq requires the AsyncMiddleware "
+            "to be configured."
+        )
+    return global_event_loop_thread
+
+
+def set_event_loop_thread(event_loop_thread: Optional["EventLoopThread"]) -> None:
+    global global_event_loop_thread
+    global_event_loop_thread = event_loop_thread
+
+
+def async_to_sync(async_fn: Callable[P, Awaitable[R]]) -> Callable[P, R]:
+    """Wrap an 'async def' function to make it synchronous."""
+    # assert presence of event loop thread:
+    get_event_loop_thread()
+
+    @functools.wraps(async_fn)
+    def wrapper(*args, **kwargs) -> R:
+        return get_event_loop_thread().run_coroutine(async_fn(*args, **kwargs))
+
+    return wrapper
 
 
 class EventLoopThread(threading.Thread):
@@ -48,7 +95,7 @@ def _stop_event_loop(self):
             self.logger.info("Stopping the event loop...")
             self.loop.call_soon_threadsafe(self.loop.stop)
 
-    def run_coroutine(self, coro: Awaitable) -> None:
+    def run_coroutine(self, coro: Awaitable[R]) -> R:
         """To be called from outside the thread
 
         Blocks until the coroutine is finished.
@@ -60,7 +107,7 @@ def run_coroutine(self, coro: Awaitable) -> None:
             try:
                 # Use a timeout to be able to catch asynchronously raised dramatiq
                 # exceptions (Interrupt).
-                future.result(timeout=self.INTERRUPT_CHECK_INTERVAL)
+                return future.result(timeout=self.INTERRUPT_CHECK_INTERVAL)
             except Interrupt:
                 # Asynchronously raised from another thread: cancel the future and
                 # reiterate to wait for possible cleanup actions.
@@ -68,8 +115,6 @@ def run_coroutine(self, coro: Awaitable) -> None:
             except TimeoutError:
                 continue
 
-            break
-
     def start(self, *args, **kwargs):
         super().start(*args, **kwargs)
         time.sleep(self.EVENT_LOOP_START_TIMEOUT)
@@ -88,65 +133,15 @@ class AsyncMiddleware(Middleware):
     This thread is used to schedule coroutines on from the worker threads.
     """
 
-    event_loop_thread: Optional[EventLoopThread] = None
-
     def __init__(self):
         self.logger = get_logger(__name__, type(self))
 
-    def run_coroutine(self, coro: Awaitable) -> None:
-        self.event_loop_thread.run_coroutine(coro)
-
     def before_worker_boot(self, broker, worker):
-        self.event_loop_thread = EventLoopThread(self.logger)
-        self.event_loop_thread.start()
+        event_loop_thread = EventLoopThread(self.logger)
+        event_loop_thread.start()
 
-        # Monkeypatch the broker to make the event loop thread reachable
-        # from an actor or from other middleware.
-        broker.run_coroutine = self.event_loop_thread.run_coroutine
+        set_event_loop_thread(event_loop_thread)
 
     def after_worker_shutdown(self, broker, worker):
-        self.event_loop_thread.join()
-        self.event_loop_thread = None
-
-        delattr(broker, "run_coroutine")
-
-
-class AsyncActor(dramatiq.Actor):
-    """To configure coroutines as a dramatiq actor.
-
-    Requires AsyncMiddleware to be active.
-
-    Example usage:
-
-    >>> @dramatiq.actor(..., actor_class=AsyncActor)
-    ... async def my_task(x):
-    ...     print(x)
-
-    Notes:
-
-    The coroutine is scheduled on an event loop that is shared between
-    worker threads. See AsyncMiddleware and EventLoopThread.
-
-    This is compatible with ShutdownNotifications ("notify_shutdown") and
-    TimeLimit ("time_limit"). Both result in an asyncio.CancelledError raised inside
-    the async function. There is currently no way to tell the two apart from
-    within the coroutine.
-    """
-
-    def __init__(self, fn, *args, **kwargs):
-        super().__init__(
-            lambda *args, **kwargs: self.broker.run_coroutine(fn(*args, **kwargs)),
-            *args,
-            **kwargs
-        )
-
-
-def async_actor(awaitable=None, **kwargs):
-    if awaitable is not None:
-        return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs)
-    else:
-
-        def wrapper(awaitable):
-            return dramatiq.actor(awaitable, actor_class=AsyncActor, **kwargs)
-
-        return wrapper
+        get_event_loop_thread().join()
+        set_event_loop_thread(None)
diff --git a/tests/middleware/test_asyncio.py b/tests/middleware/test_asyncio.py
index b7fab662..a9ed9578 100644
--- a/tests/middleware/test_asyncio.py
+++ b/tests/middleware/test_asyncio.py
@@ -4,15 +4,23 @@
 
 import pytest
 
-from dramatiq.middleware.asyncio import AsyncActor, AsyncMiddleware, EventLoopThread, async_actor
+from dramatiq.middleware.asyncio import (
+    AsyncMiddleware,
+    EventLoopThread,
+    async_to_sync,
+    get_event_loop_thread,
+    set_event_loop_thread,
+)
 
 
 @pytest.fixture
 def started_thread():
     thread = EventLoopThread(logger=mock.Mock())
     thread.start()
+    set_event_loop_thread(thread)
     yield thread
     thread.join()
+    set_event_loop_thread(None)
 
 
 @pytest.fixture
@@ -34,12 +42,12 @@ def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread):
     result = {}
 
     async def get_thread_id():
-        result["thread_id"] = threading.get_ident()
+        return threading.get_ident()
 
-    started_thread.run_coroutine(get_thread_id())
+    result = started_thread.run_coroutine(get_thread_id())
 
     # the coroutine executed in the event loop thread
-    assert result["thread_id"] == started_thread.ident
+    assert result == started_thread.ident
 
 
 def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread):
@@ -52,60 +60,59 @@ async def raise_error():
         started_thread.run_coroutine(coro)
 
 
-@mock.patch.object(EventLoopThread, "start")
-@mock.patch.object(EventLoopThread, "run_coroutine")
-def test_async_middleware_before_worker_boot(
-    EventLoopThread_run_coroutine, EventLoopThread_start
-):
-    broker = mock.Mock()
-    worker = mock.Mock()
+@mock.patch("dramatiq.middleware.asyncio.EventLoopThread")
+def test_async_middleware_before_worker_boot(EventLoopThreadMock):
     middleware = AsyncMiddleware()
 
-    middleware.before_worker_boot(broker, worker)
-
-    assert isinstance(middleware.event_loop_thread, EventLoopThread)
-
-    EventLoopThread_start.assert_called_once()
+    try:
+        middleware.before_worker_boot(None, None)
 
-    middleware.run_coroutine("foo")
-    EventLoopThread_run_coroutine.assert_called_once_with("foo")
+        assert get_event_loop_thread() is EventLoopThreadMock.return_value
 
-    # broker was patched with run_coroutine
-    broker.run_coroutine("bar")
-    EventLoopThread_run_coroutine.assert_called_with("bar")
+        EventLoopThreadMock.assert_called_once_with(middleware.logger)
+        EventLoopThreadMock().start.assert_called_once_with()
+    finally:
+        set_event_loop_thread(None)
 
 
 def test_async_middleware_after_worker_shutdown():
-    broker = mock.Mock()
-    broker.run_coroutine = lambda x: x
-    worker = mock.Mock()
+    middleware = AsyncMiddleware()
     event_loop_thread = mock.Mock()
 
-    middleware = AsyncMiddleware()
-    middleware.event_loop_thread = event_loop_thread
-    middleware.after_worker_shutdown(broker, worker)
+    set_event_loop_thread(event_loop_thread)
+
+    try:
+        middleware.after_worker_shutdown(None, None)
+
+        with pytest.raises(RuntimeError):
+            get_event_loop_thread()
+
+        event_loop_thread.join.assert_called_once_with()
+    finally:
+        set_event_loop_thread(None)
+
 
-    event_loop_thread.join.assert_called_once()
-    assert middleware.event_loop_thread is None
-    assert not hasattr(broker, "run_coroutine")
+async def async_fn(value: int = 2) -> int:
+    return value + 1
 
 
-def test_async_actor(started_thread):
-    broker = mock.Mock()
-    broker.actor_options = {"max_retries"}
+@mock.patch("dramatiq.middleware.asyncio.get_event_loop_thread")
+def test_async_to_sync(get_event_loop_thread_mocked):
+    thread = get_event_loop_thread_mocked()
 
-    @async_actor(broker=broker)
-    async def foo(*args, **kwargs):
-        pass
+    fn = async_to_sync(async_fn)
+    actual = fn(2)
+    thread.run_coroutine.assert_called_once()
+    assert actual is thread.run_coroutine()
 
-    assert isinstance(foo, AsyncActor)
 
-    foo(2, a="b")
+@pytest.mark.usefixtures("started_thread")
+def test_async_to_sync_with_actual_thread(started_thread):
+    fn = async_to_sync(async_fn)
 
-    broker.run_coroutine.assert_called_once()
+    assert fn(2) == 3
 
-    # no recursion errors here:
-    repr(foo)
 
-    # this is just to stop "never awaited" warnings
-    started_thread.run_coroutine(broker.run_coroutine.call_args[0][0])
+def test_async_to_sync_no_thread():
+    with pytest.raises(RuntimeError):
+        async_to_sync(async_fn)
diff --git a/tests/test_actors.py b/tests/test_actors.py
index 6d9caf42..7f64c4c8 100644
--- a/tests/test_actors.py
+++ b/tests/test_actors.py
@@ -482,3 +482,19 @@ def accessor(x):
     # When I try to access the current message from a non-worker thread
     # Then I should get back None
     assert CurrentMessage.get_current_message() is None
+
+
+@patch("dramatiq.middleware.asyncio.async_to_sync")
+def test_actors_can_wrap_asyncio(async_to_sync_mock, stub_broker):
+    # Define an asyncio function and wrap it in an actor
+    async def add(x, y):
+        return x + y
+
+    actor = dramatiq.actor(add)
+
+    # I expect that function to become an instance of Actor
+    assert isinstance(actor, dramatiq.Actor)
+
+    # The wrapped function should be wrapped with 'async_to_sync'
+    async_to_sync_mock.assert_called_once_with(add)
+    assert actor.fn == async_to_sync_mock.return_value