Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support asynchronous (asyncio) contexts in actors #536

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions dramatiq/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import re
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, Union, overload
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar, Union, overload

from .broker import Broker, get_broker
from .logging import get_logger
Expand Down Expand Up @@ -51,10 +52,9 @@ class Actor(Generic[P, R]):
options(dict): Arbitrary options that are passed to the broker
and middleware.
"""

def __init__(
self,
fn: Callable[P, R],
fn: Callable[P, Union[R, Awaitable[R]]],
*,
broker: Broker,
actor_name: str,
Expand All @@ -63,7 +63,12 @@ def __init__(
options: Dict[str, Any],
) -> None:
self.logger = get_logger(fn.__module__, actor_name)
self.fn = fn
if iscoroutinefunction(fn):
from dramatiq.middleware.asyncio import async_to_sync

self.fn = async_to_sync(fn)
else:
self.fn = fn # type: ignore
self.broker = broker
self.actor_name = actor_name
self.queue_name = queue_name
Expand Down
147 changes: 147 additions & 0 deletions dramatiq/middleware/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

import asyncio
import functools
import threading
import time
from concurrent.futures import TimeoutError
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, TypeVar

from dramatiq.middleware import Middleware

from ..logging import get_logger
from .threading import Interrupt

if TYPE_CHECKING:
from typing_extensions import ParamSpec

P = ParamSpec("P")
else:
P = TypeVar("P")
R = TypeVar("R")

__all__ = ["AsyncMiddleware", "async_to_sync"]

# the global event loop thread
global_event_loop_thread = None


def get_event_loop_thread() -> "EventLoopThread":
"""Get the global event loop thread.

If no global broker is set, RuntimeError error will be raised.

Returns:
Broker: The global EventLoopThread.
"""
global global_event_loop_thread
if global_event_loop_thread is None:
raise RuntimeError(
"The usage of asyncio in dramatiq requires the AsyncMiddleware "
"to be configured."
)
return global_event_loop_thread


def set_event_loop_thread(event_loop_thread: Optional["EventLoopThread"]) -> None:
global global_event_loop_thread
global_event_loop_thread = event_loop_thread


def async_to_sync(async_fn: Callable[P, Awaitable[R]]) -> Callable[P, R]:
"""Wrap an 'async def' function to make it synchronous."""
# assert presence of event loop thread:
get_event_loop_thread()

@functools.wraps(async_fn)
def wrapper(*args, **kwargs) -> R:
return get_event_loop_thread().run_coroutine(async_fn(*args, **kwargs))

return wrapper


class EventLoopThread(threading.Thread):
"""A thread that runs an asyncio event loop.

The method 'run_coroutine' should be used to run coroutines from a
synchronous context.
"""

# seconds to wait for the event loop to start
EVENT_LOOP_START_TIMEOUT = 0.1
# interval (seconds) to reactivate the worker thread and check
# for interrupts
INTERRUPT_CHECK_INTERVAL = 1.0

loop: Optional[asyncio.AbstractEventLoop] = None

def __init__(self, logger):
self.logger = logger
super().__init__(target=self._start_event_loop)

def _start_event_loop(self):
"""This method should run in the thread"""
self.logger.info("Starting the event loop...")

self.loop = asyncio.new_event_loop()
try:
self.loop.run_forever()
finally:
self.loop.close()

def _stop_event_loop(self):
"""This method should run outside of the thread"""
if self.loop is not None and self.loop.is_running():
self.logger.info("Stopping the event loop...")
self.loop.call_soon_threadsafe(self.loop.stop)

def run_coroutine(self, coro: Awaitable[R]) -> R:
"""To be called from outside the thread

Blocks until the coroutine is finished.
"""
if self.loop is None or not self.loop.is_running():
raise RuntimeError("The event loop is not running")
future = asyncio.run_coroutine_threadsafe(coro, self.loop)
while True:
try:
# Use a timeout to be able to catch asynchronously raised dramatiq
# exceptions (Interrupt).
return future.result(timeout=self.INTERRUPT_CHECK_INTERVAL)
except Interrupt:
# Asynchronously raised from another thread: cancel the future and
# reiterate to wait for possible cleanup actions.
self.loop.call_soon_threadsafe(future.cancel)
except TimeoutError:
continue

def start(self, *args, **kwargs):
super().start(*args, **kwargs)
time.sleep(self.EVENT_LOOP_START_TIMEOUT)
if self.loop is None or not self.loop.is_running():
raise RuntimeError("The event loop failed to start")
self.logger.info("Event loop is running.")

def join(self, *args, **kwargs):
self._stop_event_loop()
return super().join(*args, **kwargs)


class AsyncMiddleware(Middleware):
"""This middleware manages the event loop thread.

This thread is used to schedule coroutines on from the worker threads.
"""

def __init__(self):
self.logger = get_logger(__name__, type(self))

def before_worker_boot(self, broker, worker):
event_loop_thread = EventLoopThread(self.logger)
event_loop_thread.start()

set_event_loop_thread(event_loop_thread)

def after_worker_shutdown(self, broker, worker):
get_event_loop_thread().join()
set_event_loop_thread(None)
118 changes: 118 additions & 0 deletions tests/middleware/test_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import asyncio
import threading
from unittest import mock

import pytest

from dramatiq.middleware.asyncio import (
AsyncMiddleware,
EventLoopThread,
async_to_sync,
get_event_loop_thread,
set_event_loop_thread,
)


@pytest.fixture
def started_thread():
thread = EventLoopThread(logger=mock.Mock())
thread.start()
set_event_loop_thread(thread)
yield thread
thread.join()
set_event_loop_thread(None)


@pytest.fixture
def logger():
return mock.Mock()


def test_event_loop_thread_start():
try:
thread = EventLoopThread(logger=mock.Mock())
thread.start()
assert isinstance(thread.loop, asyncio.BaseEventLoop)
assert thread.loop.is_running()
finally:
thread.join()


def test_event_loop_thread_run_coroutine(started_thread: EventLoopThread):
result = {}

async def get_thread_id():
return threading.get_ident()

result = started_thread.run_coroutine(get_thread_id())

# the coroutine executed in the event loop thread
assert result == started_thread.ident


def test_event_loop_thread_run_coroutine_exception(started_thread: EventLoopThread):
async def raise_error():
raise TypeError("bla")

coro = raise_error()

with pytest.raises(TypeError, match="bla"):
started_thread.run_coroutine(coro)


@mock.patch("dramatiq.middleware.asyncio.EventLoopThread")
def test_async_middleware_before_worker_boot(EventLoopThreadMock):
middleware = AsyncMiddleware()

try:
middleware.before_worker_boot(None, None)

assert get_event_loop_thread() is EventLoopThreadMock.return_value

EventLoopThreadMock.assert_called_once_with(middleware.logger)
EventLoopThreadMock().start.assert_called_once_with()
finally:
set_event_loop_thread(None)


def test_async_middleware_after_worker_shutdown():
middleware = AsyncMiddleware()
event_loop_thread = mock.Mock()

set_event_loop_thread(event_loop_thread)

try:
middleware.after_worker_shutdown(None, None)

with pytest.raises(RuntimeError):
get_event_loop_thread()

event_loop_thread.join.assert_called_once_with()
finally:
set_event_loop_thread(None)


async def async_fn(value: int = 2) -> int:
return value + 1


@mock.patch("dramatiq.middleware.asyncio.get_event_loop_thread")
def test_async_to_sync(get_event_loop_thread_mocked):
thread = get_event_loop_thread_mocked()

fn = async_to_sync(async_fn)
actual = fn(2)
thread.run_coroutine.assert_called_once()
assert actual is thread.run_coroutine()


@pytest.mark.usefixtures("started_thread")
def test_async_to_sync_with_actual_thread(started_thread):
fn = async_to_sync(async_fn)

assert fn(2) == 3


def test_async_to_sync_no_thread():
with pytest.raises(RuntimeError):
async_to_sync(async_fn)
16 changes: 16 additions & 0 deletions tests/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,19 @@ def accessor(x):
# When I try to access the current message from a non-worker thread
# Then I should get back None
assert CurrentMessage.get_current_message() is None


@patch("dramatiq.middleware.asyncio.async_to_sync")
def test_actors_can_wrap_asyncio(async_to_sync_mock, stub_broker):
# Define an asyncio function and wrap it in an actor
async def add(x, y):
return x + y

actor = dramatiq.actor(add)

# I expect that function to become an instance of Actor
assert isinstance(actor, dramatiq.Actor)

# The wrapped function should be wrapped with 'async_to_sync'
async_to_sync_mock.assert_called_once_with(add)
assert actor.fn == async_to_sync_mock.return_value