Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interoperability with asyncio (part 1) #174

Merged
merged 5 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions src/dispatch/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import asyncio
import functools
import inspect
import signal
import threading


class Runner:
"""Runner is a class similar to asyncio.Runner but that we use for backward
compatibility with Python 3.10 and earlier.
"""

def __init__(self):
self._loop = asyncio.new_event_loop()
self._interrupt_count = 0

def __enter__(self):
return self

def __exit__(self, *args, **kwargs):
self.close()

def close(self):
try:
loop = self._loop
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
if hasattr(loop, "shutdown_default_executor"): # Python 3.9+
loop.run_until_complete(loop.shutdown_default_executor())
finally:
loop.close()

def get_loop(self):
return self._loop

def run(self, coro):
if not inspect.iscoroutine(coro):
raise ValueError("a coroutine was expected, got {!r}".format(coro))

try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError(
"Runner.run() cannot be called from a running event loop"
)

task = self._loop.create_task(coro)
sigint_handler = None

if (
threading.current_thread() is threading.main_thread()
and signal.getsignal(signal.SIGINT) is signal.default_int_handler
):
sigint_handler = functools.partial(self._on_sigint, main_task=task)
try:
signal.signal(signal.SIGINT, sigint_handler)
except ValueError:
# `signal.signal` may throw if `threading.main_thread` does
# not support signals (e.g. embedded interpreter with signals
# not registered - see gh-91880)
sigint_handler = None

self._interrupt_count = 0
try:
asyncio.set_event_loop(self._loop)
return self._loop.run_until_complete(task)
except asyncio.CancelledError:
if self._interrupt_count > 0:
uncancel = getattr(task, "uncancel", None)
if uncancel is not None and uncancel() == 0:
raise KeyboardInterrupt()
raise # CancelledError
finally:
asyncio.set_event_loop(None)
if (
sigint_handler is not None
and signal.getsignal(signal.SIGINT) is sigint_handler
):
signal.signal(signal.SIGINT, signal.default_int_handler)

def _on_sigint(self, signum, frame, main_task):
self._interrupt_count += 1
if self._interrupt_count == 1 and not main_task.done():
main_task.cancel()
# wakeup loop if it is blocked by select() with long timeout
self._loop.call_soon_threadsafe(lambda: None)
return
raise KeyboardInterrupt()


def _cancel_all_tasks(loop):
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
4 changes: 3 additions & 1 deletion src/dispatch/experimental/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def handler(event, context):

from awslambdaric.lambda_context import LambdaContext

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
Expand Down Expand Up @@ -92,7 +93,8 @@ def handle(

input = Input(req)
try:
output = func._primitive_call(input)
with Runner() as runner:
output = runner.run(func._primitive_call(input))
except Exception:
logger.error("function '%s' fatal error", req.function, exc_info=True)
raise # FIXME
Expand Down
6 changes: 1 addition & 5 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,7 @@ async def execute(request: fastapi.Request):
# forcing execute() to be async.
data: bytes = await request.body()

loop = asyncio.get_running_loop()

content = await loop.run_in_executor(
None,
function_service_run,
content = await function_service_run(
str(request.url),
request.method,
request.headers,
Expand Down
20 changes: 12 additions & 8 deletions src/dispatch/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def read_root():

from flask import Flask, make_response, request

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.http import FunctionServiceError, function_service_run
from dispatch.signature import Ed25519PublicKey, parse_verification_key
Expand Down Expand Up @@ -89,14 +90,17 @@ def _handle_error(self, exc: FunctionServiceError):
def _execute(self):
data: bytes = request.get_data(cache=False)

content = function_service_run(
request.url,
request.method,
dict(request.headers),
data,
self,
self._verification_key,
)
with Runner() as runner:
content = runner.run(
function_service_run(
request.url,
request.method,
dict(request.headers),
data,
self,
self._verification_key,
),
)

res = make_response(content)
res.content_type = "application/proto"
Expand Down
24 changes: 13 additions & 11 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import os
from functools import wraps
from types import CoroutineType
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Dict,
Expand All @@ -33,7 +35,7 @@
logger = logging.getLogger(__name__)


PrimitiveFunctionType: TypeAlias = Callable[[Input], Output]
PrimitiveFunctionType: TypeAlias = Callable[[Input], Awaitable[Output]]
"""A primitive function is a function that accepts a dispatch.proto.Input
and unconditionally returns a dispatch.proto.Output. It must not raise
exceptions.
Expand Down Expand Up @@ -70,8 +72,8 @@ def endpoint(self, value: str):
def name(self) -> str:
return self._name

def _primitive_call(self, input: Input) -> Output:
return self._primitive_func(input)
async def _primitive_call(self, input: Input) -> Output:
return await self._primitive_func(input)

def _primitive_dispatch(self, input: Any = None) -> DispatchID:
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
Expand Down Expand Up @@ -226,6 +228,7 @@ def function(self, func: Callable[P, T]) -> Function[P, T]: ...
def function(self, func):
"""Decorator that registers functions."""
name = func.__qualname__

if not inspect.iscoroutinefunction(func):
logger.info("registering function: %s", name)
return self._register_function(name, func)
Expand All @@ -237,23 +240,22 @@ def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]:
func = durable(func)

@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return func(*args, **kwargs)

async_wrapper.__qualname__ = f"{name}_async"
async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, func, *args, **kwargs)

return self._register_coroutine(name, async_wrapper)
asyncio_wrapper.__qualname__ = f"{name}_asyncio"
return self._register_coroutine(name, asyncio_wrapper)

def _register_coroutine(
self, name: str, func: Callable[P, Coroutine[Any, Any, T]]
) -> Function[P, T]:
logger.info("registering coroutine: %s", name)

func = durable(func)

@wraps(func)
def primitive_func(input: Input) -> Output:
return OneShotScheduler(func).run(input)
async def primitive_func(input: Input) -> Output:
return await OneShotScheduler(func).run(input)

primitive_func.__qualname__ = f"{name}_primitive"
durable_primitive_func = durable(primitive_func)
Expand Down
24 changes: 14 additions & 10 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from http_message_signatures import InvalidSignature

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
Expand Down Expand Up @@ -120,14 +121,17 @@ def do_POST(self):
url = self.requestline # TODO: need full URL

try:
content = function_service_run(
url,
method,
dict(self.headers),
data,
self.registry,
self.verification_key,
)
with Runner() as runner:
content = runner.run(
function_service_run(
url,
method,
dict(self.headers),
data,
self.registry,
self.verification_key,
)
)
except FunctionServiceError as e:
return self.send_error_response(e.status, e.code, e.message)

Expand All @@ -137,7 +141,7 @@ def do_POST(self):
self.wfile.write(content)


def function_service_run(
async def function_service_run(
url: str,
method: str,
headers: Mapping[str, str],
Expand Down Expand Up @@ -184,7 +188,7 @@ def function_service_run(
logger.info("running function '%s'", req.function)

try:
output = func._primitive_call(input)
output = await func._primitive_call(input)
except Exception:
# This indicates that an exception was raised in a primitive
# function. Primitive functions must catch exceptions, categorize
Expand Down
Loading