From 4a34c1caa383dffc8bc8d331f8005b69d420b97a Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 7 Jun 2024 16:24:47 -0700 Subject: [PATCH 01/34] aiohttp: refactor internals to use asyncio throughout the SDK Signed-off-by: Achille Roussel --- pyproject.toml | 3 +- src/dispatch/__init__.py | 4 +- src/dispatch/aiohttp.py | 131 --- src/dispatch/config.py | 52 ++ src/dispatch/experimental/lambda_handler.py | 2 +- src/dispatch/function.py | 281 ++++--- src/dispatch/http.py | 198 +++-- src/dispatch/test/__init__.py | 262 +++++- src/dispatch/test/http.py | 4 + tests/dispatch/test_config.py | 44 + tests/dispatch/test_error.py | 126 +-- tests/dispatch/test_function.py | 23 +- tests/dispatch/test_scheduler.py | 869 ++++++++++---------- tests/dispatch/test_status.py | 366 +++++---- tests/test_client.py | 129 +-- tests/test_fastapi.py | 31 +- tests/test_flask.py | 13 +- tests/test_http.py | 128 +-- 18 files changed, 1530 insertions(+), 1136 deletions(-) delete mode 100644 src/dispatch/aiohttp.py create mode 100644 src/dispatch/config.py create mode 100644 tests/dispatch/test_config.py diff --git a/pyproject.toml b/pyproject.toml index d3ec2c73..2aff75b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,8 @@ dev = [ "black >= 24.1.0", "isort >= 5.13.2", "mypy >= 1.10.0", - "pytest==8.0.0", + "pytest >= 8.0.0", + "pytest-asyncio >= 0.23.7", "fastapi >= 0.109.0", "coverage >= 7.4.1", "requests >= 2.31.0", diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 65af8bf5..812d621a 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -12,7 +12,7 @@ import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race -from dispatch.function import DEFAULT_API_URL, Batch, Client, Function, Registry, Reset +from dispatch.function import Batch, Client, ClientError, Function, Registry, Reset from dispatch.http import Dispatch from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output @@ -21,7 +21,7 @@ __all__ = [ "Call", "Client", - "DEFAULT_API_URL", + "ClientError", "DispatchID", "Error", "Input", diff --git a/src/dispatch/aiohttp.py b/src/dispatch/aiohttp.py deleted file mode 100644 index f2b42447..00000000 --- a/src/dispatch/aiohttp.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Optional, Union - -from aiohttp import web - -from dispatch.function import Registry -from dispatch.http import ( - FunctionServiceError, - function_service_run, - make_error_response_body, -) -from dispatch.signature import Ed25519PublicKey, parse_verification_key - - -class Dispatch(web.Application): - """A Dispatch instance servicing as a http server.""" - - registry: Registry - verification_key: Optional[Ed25519PublicKey] - - def __init__( - self, - registry: Registry, - verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, - ): - """Initialize a Dispatch application. - - Args: - registry: The registry of functions to be serviced. - - verification_key: The verification key to use for requests. - """ - super().__init__() - self.registry = registry - self.verification_key = parse_verification_key(verification_key) - self.add_routes( - [ - web.post( - "/dispatch.sdk.v1.FunctionService/Run", self.handle_run_request - ), - ] - ) - - async def handle_run_request(self, request: web.Request) -> web.Response: - return await function_service_run_handler( - request, self.registry, self.verification_key - ) - - -class Server: - host: str - port: int - app: Dispatch - - _runner: web.AppRunner - _site: web.TCPSite - - def __init__(self, host: str, port: int, app: Dispatch): - self.host = host - self.port = port - self.app = app - - async def __aenter__(self): - await self.start() - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - await self.stop() - - async def start(self): - self._runner = web.AppRunner(self.app) - await self._runner.setup() - - self._site = web.TCPSite(self._runner, self.host, self.port) - await self._site.start() - - async def stop(self): - await self._site.stop() - await self._runner.cleanup() - - -def make_error_response(status: int, code: str, message: str) -> web.Response: - body = make_error_response_body(code, message) - return web.Response(status=status, content_type="application/json", body=body) - - -def make_error_response_invalid_argument(message: str) -> web.Response: - return make_error_response(400, "invalid_argument", message) - - -def make_error_response_not_found(message: str) -> web.Response: - return make_error_response(404, "not_found", message) - - -def make_error_response_unauthenticated(message: str) -> web.Response: - return make_error_response(401, "unauthenticated", message) - - -def make_error_response_permission_denied(message: str) -> web.Response: - return make_error_response(403, "permission_denied", message) - - -def make_error_response_internal(message: str) -> web.Response: - return make_error_response(500, "internal", message) - - -async def function_service_run_handler( - request: web.Request, - function_registry: Registry, - verification_key: Optional[Ed25519PublicKey], -) -> web.Response: - content_length = request.content_length - if content_length is None or content_length == 0: - return make_error_response_invalid_argument("content length is required") - if content_length < 0: - return make_error_response_invalid_argument("content length is negative") - if content_length > 16_000_000: - return make_error_response_invalid_argument("content length is too large") - - data: bytes = await request.read() - try: - content = await function_service_run( - str(request.url), - request.method, - dict(request.headers), - data, - function_registry, - verification_key, - ) - except FunctionServiceError as e: - return make_error_response(e.status, e.code, e.message) - return web.Response(status=200, content_type="application/proto", body=content) diff --git a/src/dispatch/config.py b/src/dispatch/config.py new file mode 100644 index 00000000..c5010fae --- /dev/null +++ b/src/dispatch/config.py @@ -0,0 +1,52 @@ +import os +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class NamedValueFromEnvironment: + _envvar: str + _name: str + _value: str + _from_envvar: bool + + def __init__( + self, + envvar: str, + name: str, + value: Optional[str] = None, + from_envvar: bool = False, + ): + self._envvar = envvar + self._name = name + self._from_envvar = from_envvar + if value is None: + self._value = os.environ.get(envvar) or "" + self._from_envvar = True + else: + self._value = value + + def __str__(self): + return self.value + + def __getstate__(self): + return (self._envvar, self._name, self._value, self._from_envvar) + + def __setstate__(self, state): + (self._envvar, self._name, self._value, self._from_envvar) = state + if self._from_envvar: + self._value = os.environ.get(self._envvar) or "" + self._from_envvar = True + + @property + def name(self) -> str: + return self._envvar if self._from_envvar else self._name + + @property + def value(self) -> str: + return self._value + + @value.setter + def value(self, value: str): + self._value = value + self._from_envvar = False diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 5e54a48c..6aeeaca6 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -1,4 +1,4 @@ -"""Integration of Dispatch programmable endpoints for FastAPI. +"""Integration of Dispatch programmable endpoints for AWS Lambda. Example: diff --git a/src/dispatch/function.py b/src/dispatch/function.py index e9f7a321..75302c78 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -2,8 +2,10 @@ import asyncio import inspect +import json import logging import os +import threading from functools import wraps from types import CoroutineType from typing import ( @@ -16,25 +18,66 @@ Iterable, List, Optional, + Tuple, TypeVar, overload, ) from urllib.parse import urlparse -import grpc +import aiohttp from typing_extensions import ParamSpec, TypeAlias import dispatch.coroutine +import dispatch.sdk.v1.call_pb2 as call_pb import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb -import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc +from dispatch.config import NamedValueFromEnvironment from dispatch.experimental.durable import durable from dispatch.id import DispatchID -from dispatch.proto import Arguments, Call, Error, Input, Output, TailCall +from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall from dispatch.scheduler import OneShotScheduler logger = logging.getLogger(__name__) +class GlobalSession(aiohttp.ClientSession): + async def __aexit__(self, *args): + pass # don't close global sessions when used as context managers + + +DEFAULT_API_URL: str = "https://api.dispatch.run" +DEFAULT_SESSION: Optional[aiohttp.ClientSession] = None + + +def current_session() -> aiohttp.ClientSession: + global DEFAULT_SESSION + if DEFAULT_SESSION is None: + DEFAULT_SESSION = GlobalSession() + return DEFAULT_SESSION + + +class ThreadContext(threading.local): + in_function_call: bool + + def __init__(self): + self.in_function_call = False + + +thread_context = ThreadContext() + + +def function(func: Callable[P, T]) -> Callable[P, T]: + def scope(*args: P.args, **kwargs: P.kwargs) -> T: + if thread_context.in_function_call: + raise RuntimeError("recursively entered a dispatch function entry point") + thread_context.in_function_call = True + try: + return func(*args, **kwargs) + finally: + thread_context.in_function_call = False + + return scope + + PrimitiveFunctionType: TypeAlias = Callable[[Input], Awaitable[Output]] """A primitive function is a function that accepts a dispatch.proto.Input and unconditionally returns a dispatch.proto.Output. It must not raise @@ -42,11 +85,12 @@ """ -DEFAULT_API_URL = "https://api.dispatch.run" - - class PrimitiveFunction: __slots__ = ("_endpoint", "_client", "_name", "_primitive_func") + _endpoint: str + _client: Client + _name: str + _primitive_function: PrimitiveFunctionType def __init__( self, @@ -75,8 +119,8 @@ def name(self) -> str: async def _primitive_call(self, input: Input) -> Output: return await self._primitive_func(input) - def _primitive_dispatch(self, input: Any = None) -> DispatchID: - [dispatch_id] = self._client.dispatch([self._build_primitive_call(input)]) + async def _primitive_dispatch(self, input: Any = None) -> DispatchID: + [dispatch_id] = await self._client.dispatch([self._build_primitive_call(input)]) return dispatch_id def _build_primitive_call( @@ -109,20 +153,24 @@ def __init__( primitive_func: PrimitiveFunctionType, ): PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func) - self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable( self._call_async ) async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T: - return await dispatch.coroutine.call( - self.build_call(*args, **kwargs, correlation_id=None) - ) + return await dispatch.coroutine.call(self.build_call(*args, **kwargs)) - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: + async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: """Call the function asynchronously (through Dispatch), and return a coroutine that can be awaited to retrieve the call result.""" - return self._func_indirect(*args, **kwargs) + if thread_context.in_function_call: + return await self._func_indirect(*args, **kwargs) + + call = self.build_call(*args, **kwargs) + + [dispatch_id] = await self._client.dispatch([call]) + + return await self._client.wait(dispatch_id) def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: """Dispatch an asynchronous call to the function without @@ -137,30 +185,21 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: Returns: DispatchID: ID of the dispatched call. - - Raises: - RuntimeError: if a Dispatch client has not been configured. """ - return self._primitive_dispatch(Arguments(args, kwargs)) + return asyncio.run(self._primitive_dispatch(Arguments(args, kwargs))) - def build_call( - self, *args: P.args, correlation_id: Optional[int] = None, **kwargs: P.kwargs - ) -> Call: + def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call: """Create a Call for this function with the provided input. Useful to generate calls when using the Client. Args: *args: Positional arguments for the function. - correlation_id: optional arbitrary integer the caller can use to - match this call to a call result. **kwargs: Keyword arguments for the function. Returns: Call: can be passed to Client.dispatch. """ - return self._build_primitive_call( - Arguments(args, kwargs), correlation_id=correlation_id - ) + return self._build_primitive_call(Arguments(args, kwargs)) class Reset(TailCall): @@ -168,7 +207,7 @@ class Reset(TailCall): the call embedded in this exception.""" def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs): - super().__init__(call=func.build_call(*args, correlation_id=None, **kwargs)) + super().__init__(call=func.build_call(*args, **kwargs)) class Registry: @@ -240,6 +279,7 @@ def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]: func = durable(func) @wraps(func) + @function async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, func, *args, **kwargs) @@ -254,6 +294,7 @@ def _register_coroutine( func = durable(func) @wraps(func) + @function async def primitive_func(input: Input) -> Output: return await OneShotScheduler(func).run(input) @@ -261,7 +302,10 @@ async def primitive_func(input: Input) -> Output: durable_primitive_func = durable(primitive_func) wrapped_func = Function[P, T]( - self.endpoint, self.client, name, durable_primitive_func + self.endpoint, + self.client, + name, + durable_primitive_func, ) self._register(name, wrapped_func) return wrapped_func @@ -273,7 +317,10 @@ def primitive_function( name = primitive_func.__qualname__ logger.info("registering primitive function: %s", name) wrapped_func = PrimitiveFunction( - self.endpoint, self.client, name, primitive_func + self.endpoint, + self.client, + name, + primitive_func, ) self._register(name, wrapped_func) return wrapped_func @@ -283,17 +330,19 @@ def _register(self, name: str, wrapped_func: PrimitiveFunction): raise ValueError(f"function already registered with name '{name}'") self.functions[name] = wrapped_func + def batch(self): # -> Batch: + """Returns a Batch instance that can be used to build + a set of calls to dispatch.""" + # return self.client.batch() + raise NotImplemented + def set_client(self, client: Client): """Set the Client instance used to dispatch calls to registered functions.""" + # TODO: figure out a way to remove this method, it's only used in examples self.client = client for fn in self.functions.values(): fn._client = client - def batch(self) -> Batch: - """Returns a Batch instance that can be used to build - a set of calls to dispatch.""" - return self.client.batch() - def override_endpoint(self, endpoint: str): for fn in self.functions.values(): fn.endpoint = endpoint @@ -302,7 +351,10 @@ def override_endpoint(self, endpoint: str): class Client: """Client for the Dispatch API.""" - __slots__ = ("api_url", "api_key", "_stub", "api_key_from") + __slots__ = ("api_url", "api_key") + + api_url: NamedValueFromEnvironment + api_key: NamedValueFromEnvironment def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None): """Create a new Dispatch client. @@ -318,58 +370,45 @@ def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None) Raises: ValueError: if the API key is missing. """ + self.api_url = NamedValueFromEnvironment("DISPATCH_API_URL", "api_url", api_url) + self.api_key = NamedValueFromEnvironment("DISPATCH_API_KEY", "api_key", api_key) - if api_key: - self.api_key_from = "api_key" - else: - self.api_key_from = "DISPATCH_API_KEY" - api_key = os.environ.get("DISPATCH_API_KEY") - if not api_key: + if not self.api_key.value: raise ValueError( "missing API key: set it with the DISPATCH_API_KEY environment variable" ) - if not api_url: - api_url = os.environ.get("DISPATCH_API_URL", DEFAULT_API_URL) - if not api_url: - raise ValueError( - "missing API URL: set it with the DISPATCH_API_URL environment variable" - ) + if not self.api_url.value: + if "DISPATCH_API_URL" in os.environ: + raise ValueError( + "missing API URL: set it with the DISPATCH_API_URL environment variable" + ) + self.api_url._value = DEFAULT_API_URL - logger.debug("initializing client for Dispatch API at URL %s", api_url) - self.api_url = api_url - self.api_key = api_key - self._init_stub() - - def __getstate__(self): - return {"api_url": self.api_url, "api_key": self.api_key} - - def __setstate__(self, state): - self.api_url = state["api_url"] - self.api_key = state["api_key"] - self._init_stub() - - def _init_stub(self): - result = urlparse(self.api_url) - if result.scheme == "http": - creds = grpc.local_channel_credentials() - elif result.scheme == "https": - creds = grpc.ssl_channel_credentials() - else: + result = urlparse(self.api_url.value) + if result.scheme not in ("http", "https"): raise ValueError(f"Invalid API scheme: '{result.scheme}'") - call_creds = grpc.access_token_call_credentials(self.api_key) - creds = grpc.composite_channel_credentials(creds, call_creds) - channel = grpc.secure_channel(result.netloc, creds) - - self._stub = dispatch_grpc.DispatchServiceStub(channel) - - def batch(self) -> Batch: - """Returns a Batch instance that can be used to build - a set of calls to dispatch.""" - return Batch(self) + logger.debug( + "initializing client for Dispatch API at URL %s", self.api_url.value + ) - def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]: + def session(self) -> aiohttp.ClientSession: + return current_session() + + def request( + self, path: str, timeout: int = 5 + ) -> Tuple[str, dict[str, str], aiohttp.ClientTimeout]: + # https://connectrpc.com/docs/protocol/#unary-request + headers = { + "Authorization": "Bearer " + self.api_key.value, + "Content-Type": "application/proto", + "Connect-Protocol-Version": "1", + "Connect-Timeout-Ms": str(timeout * 1000), + } + return self.api_url.value + path, headers, aiohttp.ClientTimeout(total=timeout) + + async def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]: """Dispatch function calls. Args: @@ -380,17 +419,21 @@ def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]: """ calls_proto = [c._as_proto() for c in calls] logger.debug("dispatching %d function call(s)", len(calls_proto)) - req = dispatch_pb.DispatchRequest(calls=calls_proto) + data = dispatch_pb.DispatchRequest(calls=calls_proto).SerializeToString() - try: - resp = self._stub.Dispatch(req) - except grpc.RpcError as e: - status_code = e.code() - if status_code == grpc.StatusCode.UNAUTHENTICATED: - raise PermissionError( - f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)" - ) from e - raise + (url, headers, timeout) = self.request( + "/dispatch.sdk.v1.DispatchService/Dispatch" + ) + + async with self.session() as session: + async with session.post( + url, headers=headers, data=data, timeout=timeout + ) as res: + data = await res.read() + self._check_response(res.status, data) + + resp = dispatch_pb.DispatchResponse() + resp.ParseFromString(data) dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids] if logger.isEnabledFor(logging.DEBUG): @@ -401,19 +444,70 @@ def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]: ) return dispatch_ids + async def wait(self, dispatch_id: DispatchID) -> Any: + (url, headers, timeout) = self.request("/dispatch.sdk.v1.DispatchService/Wait") + data = dispatch_id.encode("utf-8") + + async with self.session() as session: + async with session.post( + url, headers=headers, data=data, timeout=timeout + ) as res: + data = await res.read() + self._check_response(res.status, data) + + resp = call_pb.CallResult() + resp.ParseFromString(data) + + result = CallResult._from_proto(resp) + if result.error is not None: + raise result.error.to_exception() + return result.output + + def _check_response(self, status: int, data: bytes): + if status == 200: + return + if status == 401: + raise PermissionError( + f"Dispatch received an invalid authentication token (check {self.api_key.name} is correct)" + ) + raise ClientError.from_response(status, data) + + +class ClientError(aiohttp.ClientError): + status: int + code: str + message: str + + def __init__( + self, status: int = 0, code: str = "unknown", message: str = "unknown" + ): + self.status = status + self.code = code + self.message = message + super().__init__(f"{code}: {message}") + + @classmethod + def from_response(cls, status: int, body: bytes) -> ClientError: + error_dict = json.loads(body) + error_code = str(error_dict.get("code")) or "unknown" + error_message = str(error_dict.get("message")) or "unknown" + return cls(status, error_code, error_message) + class Batch: """A batch of calls to dispatch.""" __slots__ = ("client", "calls") + client: Client + calls: List[Call] def __init__(self, client: Client): self.client = client - self.calls: List[Call] = [] + self.calls = [] def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch: """Add a call to the specified function to the batch.""" - return self.add_call(func.build_call(*args, correlation_id=None, **kwargs)) + return self.add_call(func.build_call(*args, **kwargs)) def add_call(self, call: Call) -> Batch: """Add a Call to the batch.""" @@ -431,11 +525,10 @@ def dispatch(self) -> List[DispatchID]: """ if not self.calls: return [] - - dispatch_ids = self.client.dispatch(self.calls) - self.reset() + dispatch_ids = asyncio.run(self.client.dispatch(self.calls)) + self.clear() return dispatch_ids - def reset(self): + def clear(self): """Reset the batch.""" self.calls = [] diff --git a/src/dispatch/http.py b/src/dispatch/http.py index 19af8a71..bfc4c5e8 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -1,14 +1,15 @@ """Integration of Dispatch functions with http.""" +import asyncio import logging import os from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Mapping, Optional, Union +from typing import Iterable, List, Mapping, Optional, Union +from aiohttp import web from http_message_signatures import InvalidSignature -from dispatch.asyncio import Runner from dispatch.function import Registry from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb @@ -24,37 +25,6 @@ logger = logging.getLogger(__name__) -class Dispatch: - """A Dispatch instance servicing as a http server.""" - - registry: Registry - verification_key: Optional[Ed25519PublicKey] - - def __init__( - self, - registry: Registry, - verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, - ): - """Initialize a Dispatch http handler. - - Args: - registry: The registry of functions to be serviced. - - verification_key: The verification key to use for requests. - """ - self.registry = registry - self.verification_key = parse_verification_key(verification_key) - - def __call__(self, request, client_address, server): - return FunctionService( - request, - client_address, - server, - registry=self.registry, - verification_key=self.verification_key, - ) - - class FunctionServiceError(Exception): __slots__ = ("status", "code", "message") @@ -124,17 +94,16 @@ def do_POST(self): url = self.requestline # TODO: need full URL try: - with Runner() as runner: - content = runner.run( - function_service_run( - url, - method, - dict(self.headers), - data, - self.registry, - self.verification_key, - ) + content = asyncio.run( + function_service_run( + url, + method, + dict(self.headers), + data, + self.registry, + self.verification_key, ) + ) except FunctionServiceError as e: return self.send_error_response(e.status, e.code, e.message) @@ -144,6 +113,145 @@ def do_POST(self): self.wfile.write(content) +class Dispatch(web.Application): + """A Dispatch instance servicing as a http server.""" + + registry: Registry + verification_key: Optional[Ed25519PublicKey] + + def __init__( + self, + registry: Registry, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + ): + """Initialize a Dispatch application. + + Args: + registry: The registry of functions to be serviced. + + verification_key: The verification key to use for requests. + """ + super().__init__() + self.registry = registry + self.verification_key = parse_verification_key(verification_key) + self.add_routes( + [ + web.post( + "/dispatch.sdk.v1.FunctionService/Run", self.handle_run_request + ), + ] + ) + + def __call__(self, request, client_address, server): + return FunctionService( + request, + client_address, + server, + registry=self.registry, + verification_key=self.verification_key, + ) + + async def handle_run_request(self, request: web.Request) -> web.Response: + return await function_service_run_handler( + request, self.registry, self.verification_key + ) + + +class Server: + host: str + port: int + app: web.Application + + _runner: web.AppRunner + _site: web.TCPSite + + def __init__(self, host: str, port: int, app: web.Application): + self.host = host + self.port = port + self.app = app + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.stop() + + async def start(self): + self._runner = web.AppRunner(self.app) + await self._runner.setup() + + self._site = web.TCPSite(self._runner, self.host, self.port) + await self._site.start() + + if self.port == 0: + assert self._site._server is not None + assert hasattr(self._site._server, "sockets") + sockets = self._site._server.sockets + self.port = sockets[0].getsockname()[1] if sockets else 0 + + async def stop(self): + await self._site.stop() + await self._runner.cleanup() + + +def make_error_response_body(code: str, message: str) -> bytes: + return f'{{"code":"{code}","message":"{message}"}}'.encode() + + +def make_error_response(status: int, code: str, message: str) -> web.Response: + body = make_error_response_body(code, message) + return web.Response(status=status, content_type="application/json", body=body) + + +def make_error_response_invalid_argument(message: str) -> web.Response: + return make_error_response(400, "invalid_argument", message) + + +def make_error_response_not_found(message: str) -> web.Response: + return make_error_response(404, "not_found", message) + + +def make_error_response_unauthenticated(message: str) -> web.Response: + return make_error_response(401, "unauthenticated", message) + + +def make_error_response_permission_denied(message: str) -> web.Response: + return make_error_response(403, "permission_denied", message) + + +def make_error_response_internal(message: str) -> web.Response: + return make_error_response(500, "internal", message) + + +async def function_service_run_handler( + request: web.Request, + function_registry: Registry, + verification_key: Optional[Ed25519PublicKey], +) -> web.Response: + content_length = request.content_length + if content_length is None or content_length == 0: + return make_error_response_invalid_argument("content length is required") + if content_length < 0: + return make_error_response_invalid_argument("content length is negative") + if content_length > 16_000_000: + return make_error_response_invalid_argument("content length is too large") + + data: bytes = await request.read() + try: + content = await function_service_run( + str(request.url), + request.method, + dict(request.headers), + data, + function_registry, + verification_key, + ) + except FunctionServiceError as e: + return make_error_response(e.status, e.code, e.message) + return web.Response(status=200, content_type="application/proto", body=content) + + async def function_service_run( url: str, method: str, @@ -237,7 +345,3 @@ async def function_service_run( logger.debug("finished handling run request with status %s", status.name) return response.SerializeToString() - - -def make_error_response_body(code: str, message: str) -> bytes: - return f'{{"code":"{code}","message":"{message}"}}'.encode() diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index da8a3f74..f8179324 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -1,5 +1,265 @@ +import asyncio +import unittest +from datetime import datetime, timedelta +from functools import wraps +from typing import Any, Callable, Coroutine, Optional, TypeVar, overload + +import aiohttp +from aiohttp import web +from google.protobuf.timestamp_pb2 import Timestamp + +import dispatch.experimental.durable.registry +from dispatch.function import Client as BaseClient +from dispatch.function import Registry as BaseRegistry +from dispatch.http import Dispatch +from dispatch.http import Server as BaseServer +from dispatch.sdk.v1.call_pb2 import Call, CallResult +from dispatch.sdk.v1.dispatch_pb2 import DispatchRequest, DispatchResponse +from dispatch.sdk.v1.error_pb2 import Error +from dispatch.sdk.v1.function_pb2 import RunRequest, RunResponse +from dispatch.sdk.v1.poll_pb2 import PollResult +from dispatch.sdk.v1.status_pb2 import STATUS_OK + from .client import EndpointClient from .server import DispatchServer from .service import DispatchService -__all__ = ["EndpointClient", "DispatchServer", "DispatchService"] +__all__ = [ + "EndpointClient", + "DispatchServer", + "DispatchService", + "function", + "method", + "main", + "run", + "Client", + "Server", + "Service", + "Registry", + "DISPATCH_ENDPOINT_URL", + "DISPATCH_API_URL", + "DISPATCH_API_KEY", +] + +R = TypeVar("R", bound=BaseRegistry) +T = TypeVar("T") + +DISPATCH_ENDPOINT_URL = "http://localhost:0" +DISPATCH_API_URL = "http://localhost:0" +DISPATCH_API_KEY = "916CC3D280BB46DDBDA984B3DD10059A" + + +class Client(BaseClient): + def session(self) -> aiohttp.ClientSession: + # Use an individual sessionn in the test client instead of the default + # global session from dispatch.http so we don't crash when a different + # event loop is employed in each test. + return aiohttp.ClientSession() + + +class Registry(BaseRegistry): + def __init__(self): + # placeholder values to initialize the base class prior to binding + # random ports. + super().__init__( + endpoint=DISPATCH_ENDPOINT_URL, + api_url=DISPATCH_API_URL, + api_key=DISPATCH_API_KEY, + ) + + +class Server(BaseServer): + def __init__(self, app: web.Application): + super().__init__("localhost", 0, app) + + @property + def url(self): + return f"http://{self.host}:{self.port}" + + +class Service(web.Application): + tasks: dict[str, asyncio.Task[CallResult]] + + def __init__(self): + super().__init__() + self.dispatch_ids = (str(i) for i in range(2**32 - 1)) + self.tasks = {} + self.add_routes( + [ + web.post( + "/dispatch.sdk.v1.DispatchService/Dispatch", + self.handle_dispatch_request, + ), + web.post( + "/dispatch.sdk.v1.DispatchService/Wait", + self.handle_wait_request, + ), + ] + ) + + async def authenticate(self, request: web.Request): + auth = request.headers.get("Authorization") + if not auth or not auth.startswith("Bearer "): + raise web.HTTPUnauthorized(text="missing authentication token") + + token = auth[len("Bearer ") :] + if token != DISPATCH_API_KEY: + raise web.HTTPUnauthorized(text="invalid authentication token") + + async def handle_dispatch_request(self, request: web.Request): + await self.authenticate(request) + req = DispatchRequest.FromString(await request.read()) + async with aiohttp.ClientSession() as session: + res = await self.dispatch(session, req) + return web.Response( + content_type="application/proto", body=res.SerializeToString() + ) + + async def handle_wait_request(self, request: web.Request): + await self.authenticate(request) + req = str(await request.read(), "utf-8") + res = await self.wait(req) + return web.Response( + content_type="application/proto", body=res.SerializeToString() + ) + + async def dispatch( + self, session: aiohttp.ClientSession, req: DispatchRequest + ) -> DispatchResponse: + dispatch_ids = [next(self.dispatch_ids) for _ in req.calls] + + for call, dispatch_id in zip(req.calls, dispatch_ids): + self.tasks[dispatch_id] = asyncio.create_task( + self.call(session, call, dispatch_id) + ) + + return DispatchResponse(dispatch_ids=dispatch_ids) + + # TODO: add to protobuf definitions + async def wait(self, dispatch_id: str) -> CallResult: + return await self.tasks[dispatch_id] + + async def call( + self, + session: aiohttp.ClientSession, + call: Call, + dispatch_id: str, + parent_dispatch_id: Optional[str] = None, + root_dispatch_id: Optional[str] = None, + ) -> CallResult: + root_dispatch_id = root_dispatch_id or dispatch_id + + now = datetime.now() + exp = now + ( + timedelta( + seconds=call.expiration.seconds, + microseconds=call.expiration.nanos // 1000, + ) + if call.expiration + else timedelta(seconds=60) + ) + + creation_time = Timestamp() + creation_time.FromDatetime(now) + + expiration_time = Timestamp() + expiration_time.FromDatetime(exp) + + req = RunRequest( + function=call.function, + input=call.input, + creation_time=creation_time, + expiration_time=expiration_time, + dispatch_id=dispatch_id, + parent_dispatch_id=parent_dispatch_id, + root_dispatch_id=root_dispatch_id, + ) + + endpoint = call.endpoint + while True: + res = await self.run(session, endpoint, req) + + if res.status != STATUS_OK: + # TODO: emulate retries etc... + return CallResult( + dispatch_id=dispatch_id, + error=Error(type="status", message=str(res.status)), + ) + + if res.exit: + if res.exit.tail_call: + req.function = res.exit.tail_call.function + req.input = res.exit.tail_call.input + req.poll_result = None # type: ignore + continue + return CallResult( + dispatch_id=dispatch_id, + output=res.exit.result.output, + error=res.exit.result.error, + ) + + for call in res.poll.calls: + if not call.endpoint: + call.endpoint = endpoint + + # TODO: enforce poll limits + req.input = None # type: ignore + req.poll_result = PollResult( + coroutine_state=res.poll.coroutine_state, + results=await asyncio.gather( + *[ + self.call(session, call, dispatch_id) + for call, dispatch_id in zip( + res.poll.calls, next(self.dispatch_ids) + ) + ] + ), + ) + + async def run( + self, session: aiohttp.ClientSession, endpoint: str, req: RunRequest + ) -> RunResponse: + async with await session.post( + f"{endpoint}/dispatch.sdk.v1.FunctionService/Run", + data=req.SerializeToString(), + ) as response: + return RunResponse.FromString(await response.read()) + + +async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: + api = Service() + app = Dispatch(reg) + try: + async with Server(api) as backend: + async with Server(app) as server: + # Here we break through the abstraction layers a bit, it's not + # ideal but it works for now. + reg.client.api_url.value = backend.url + reg.endpoint = server.url + await fn(reg) + finally: + # TODO: let's figure out how to get rid of this global registry + # state at some point, which forces tests to be run sequentially. + dispatch.experimental.durable.registry.clear_functions() + + +def run(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: + return asyncio.run(main(reg, fn)) + + +def function(fn: Callable[[Registry], Coroutine[Any, Any, None]]) -> Callable[[], None]: + @wraps(fn) + def wrapper(): + return run(Registry(), fn) + + return wrapper + + +def method( + fn: Callable[[T, Registry], Coroutine[Any, Any, None]] +) -> Callable[[T], None]: + @wraps(fn) + def wrapper(self: T): + return run(Registry(), lambda reg: fn(self, reg)) + + return wrapper diff --git a/src/dispatch/test/http.py b/src/dispatch/test/http.py index d811a763..cf7ba9fa 100644 --- a/src/dispatch/test/http.py +++ b/src/dispatch/test/http.py @@ -1,6 +1,10 @@ from dataclasses import dataclass from typing import Mapping, Protocol +import aiohttp + +from dispatch.function import Client as DefaultClient + @dataclass class HttpResponse(Protocol): diff --git a/tests/dispatch/test_config.py b/tests/dispatch/test_config.py new file mode 100644 index 00000000..52dc8651 --- /dev/null +++ b/tests/dispatch/test_config.py @@ -0,0 +1,44 @@ +import os +import pickle +from unittest import mock + +from dispatch.config import NamedValueFromEnvironment + + +def test_value_preset(): + v = NamedValueFromEnvironment("FOO", "foo", "bar") + assert v.name == "foo" + assert v.value == "bar" + + +@mock.patch.dict(os.environ, {"FOO": "bar"}) +def test_value_from_envvar(): + v = NamedValueFromEnvironment("FOO", "foo") + assert v.name == "FOO" + assert v.value == "bar" + + +@mock.patch.dict(os.environ, {"FOO": "bar"}) +def test_value_pickle_reload_from_preset(): + v = NamedValueFromEnvironment("FOO", "foo", "hi!") + assert v.name == "foo" + assert v.value == "hi!" + + s = pickle.dumps(v) + v = pickle.loads(s) + assert v.name == "foo" + assert v.value == "hi!" + + +@mock.patch.dict(os.environ, {"FOO": "bar"}) +def test_value_pickle_reload_from_envvar(): + v = NamedValueFromEnvironment("FOO", "foo") + assert v.name == "FOO" + assert v.value == "bar" + + s = pickle.dumps(v) + os.environ["FOO"] = "baz" + + v = pickle.loads(s) + assert v.name == "FOO" + assert v.value == "baz" diff --git a/tests/dispatch/test_error.py b/tests/dispatch/test_error.py index 2ca14e43..fa801194 100644 --- a/tests/dispatch/test_error.py +++ b/tests/dispatch/test_error.py @@ -1,72 +1,86 @@ import traceback -import unittest -from dispatch.proto import Error +import pytest +from dispatch.proto import Error, Status -class TestError(unittest.TestCase): - def test_conversion_between_exception_and_error(self): - try: - raise ValueError("test") - except Exception as e: - original_exception = e - error = Error.from_exception(e) - original_traceback = "".join( - traceback.format_exception( - original_exception.__class__, - original_exception, - original_exception.__traceback__, - ) + +def test_error_with_ok_status(): + with pytest.raises(ValueError): + Error(Status.OK, type="type", message="yep") + + +def test_from_exception_timeout(): + err = Error.from_exception(TimeoutError()) + assert Status.TIMEOUT == err.status + + +def test_from_exception_syntax_error(): + err = Error.from_exception(SyntaxError()) + assert Status.PERMANENT_ERROR == err.status + + +def test_conversion_between_exception_and_error(): + try: + raise ValueError("test") + except Exception as e: + original_exception = e + error = Error.from_exception(e) + original_traceback = "".join( + traceback.format_exception( + original_exception.__class__, + original_exception, + original_exception.__traceback__, ) + ) - # For some reasons traceback.format_exception does not include the caret - # (^) in the original traceback, but it does in the reconstructed one, - # so we strip it out to be able to compare the two. - def strip_caret(s): - return "\n".join( - [l for l in s.split("\n") if not l.strip().startswith("^")] - ) + # For some reasons traceback.format_exception does not include the caret + # (^) in the original traceback, but it does in the reconstructed one, + # so we strip it out to be able to compare the two. + def strip_caret(s): + return "\n".join([l for l in s.split("\n") if not l.strip().startswith("^")]) - reconstructed_exception = error.to_exception() - reconstructed_traceback = strip_caret( - "".join( - traceback.format_exception( - reconstructed_exception.__class__, - reconstructed_exception, - reconstructed_exception.__traceback__, - ) + reconstructed_exception = error.to_exception() + reconstructed_traceback = strip_caret( + "".join( + traceback.format_exception( + reconstructed_exception.__class__, + reconstructed_exception, + reconstructed_exception.__traceback__, ) ) + ) - assert type(reconstructed_exception) is type(original_exception) - assert str(reconstructed_exception) == str(original_exception) - assert original_traceback == reconstructed_traceback - - error2 = Error.from_exception(reconstructed_exception) - reconstructed_exception2 = error2.to_exception() - reconstructed_traceback2 = strip_caret( - "".join( - traceback.format_exception( - reconstructed_exception2.__class__, - reconstructed_exception2, - reconstructed_exception2.__traceback__, - ) + assert type(reconstructed_exception) is type(original_exception) + assert str(reconstructed_exception) == str(original_exception) + assert original_traceback == reconstructed_traceback + + error2 = Error.from_exception(reconstructed_exception) + reconstructed_exception2 = error2.to_exception() + reconstructed_traceback2 = strip_caret( + "".join( + traceback.format_exception( + reconstructed_exception2.__class__, + reconstructed_exception2, + reconstructed_exception2.__traceback__, ) ) + ) + + assert type(reconstructed_exception2) is type(original_exception) + assert str(reconstructed_exception2) == str(original_exception) + assert original_traceback == reconstructed_traceback2 - assert type(reconstructed_exception2) is type(original_exception) - assert str(reconstructed_exception2) == str(original_exception) - assert original_traceback == reconstructed_traceback2 - def test_conversion_without_traceback(self): - try: - raise ValueError("test") - except Exception as e: - original_exception = e +def test_conversion_without_traceback(): + try: + raise ValueError("test") + except Exception as e: + original_exception = e - error = Error.from_exception(original_exception) - error.traceback = b"" + error = Error.from_exception(original_exception) + error.traceback = b"" - reconstructed_exception = error.to_exception() - assert type(reconstructed_exception) is type(original_exception) - assert str(reconstructed_exception) == str(original_exception) + reconstructed_exception = error.to_exception() + assert type(reconstructed_exception) is type(original_exception) + assert str(reconstructed_exception) == str(original_exception) diff --git a/tests/dispatch/test_function.py b/tests/dispatch/test_function.py index 0befc0f7..3550b4b5 100644 --- a/tests/dispatch/test_function.py +++ b/tests/dispatch/test_function.py @@ -1,21 +1,14 @@ import pickle -import unittest -from dispatch.function import Client, Registry +from dispatch.test import Registry -class TestFunction(unittest.TestCase): - def setUp(self): - self.dispatch = Registry( - endpoint="http://example.com", - api_url="http://dispatch.com", - api_key="foobar", - ) +def test_serializable(): + reg = Registry() - def test_serializable(self): - @self.dispatch.function - def my_function(): - pass + @reg.function + def my_function(): + pass - s = pickle.dumps(my_function) - pickle.loads(s) + s = pickle.dumps(my_function) + pickle.loads(s) diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index 554b6339..e82d0db4 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -1,7 +1,8 @@ import unittest from typing import Any, Callable, List, Optional, Set, Type -from dispatch.asyncio import Runner +import pytest + from dispatch.coroutine import AnyException, any, call, gather, race from dispatch.experimental.durable import durable from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall @@ -53,392 +54,445 @@ async def raises_error(): raise ValueError("oops") -class TestOneShotScheduler(unittest.TestCase): - def setUp(self): - self.runner = Runner() - - def tearDown(self): - self.runner.close() +@pytest.mark.asyncio +async def test_main_return(): + @durable + async def main(): + return 1 - def test_main_return(self): - @durable - async def main(): - return 1 + output = await start(main) + assert_exit_result_value(output, 1) - output = self.start(main) - self.assert_exit_result_value(output, 1) - def test_main_raise(self): - @durable - async def main(): - raise ValueError("oops") +@pytest.mark.asyncio +async def test_main_raise(): + @durable + async def main(): + raise ValueError("oops") - output = self.start(main) - self.assert_exit_result_error(output, ValueError, "oops") + output = await start(main) + assert_exit_result_error(output, ValueError, "oops") - def test_main_args(self): - @durable - async def main(a, b=1): - return a + b - output = self.start(main, 2, b=10) - self.assert_exit_result_value(output, 12) +@pytest.mark.asyncio +async def test_main_args(): + @durable + async def main(a, b=1): + return a + b - def test_call_one(self): - output = self.start(call_one, "foo") + output = await start(main, 2, b=10) + assert_exit_result_value(output, 12) - self.assert_poll_call_functions(output, ["foo"]) - def test_call_concurrently(self): - output = self.start(call_concurrently, "foo", "bar", "baz") +@pytest.mark.asyncio +async def test_call_one(): + output = await start(call_one, "foo") - self.assert_poll_call_functions(output, ["foo", "bar", "baz"]) + assert_poll_call_functions(output, ["foo"]) - def test_call_one_indirect(self): - @durable - async def main(): - return await call_one("foo") - output = self.start(main) +@pytest.mark.asyncio +async def test_call_concurrently(): + output = await start(call_concurrently, "foo", "bar", "baz") - self.assert_poll_call_functions(output, ["foo"]) + assert_poll_call_functions(output, ["foo", "bar", "baz"]) - def test_call_concurrently_indirect(self): - @durable - async def main(*functions): - return await call_concurrently(*functions) - output = self.start(main, "foo", "bar", "baz") +@pytest.mark.asyncio +async def test_call_one_indirect(): + @durable + async def main(): + return await call_one("foo") - self.assert_poll_call_functions(output, ["foo", "bar", "baz"]) + output = await start(main) - def test_depth_run(self): - @durable - async def main(): - return await gather( - call_concurrently("a", "b", "c"), - call_one("d"), - call_concurrently("e", "f", "g"), - call_one("h"), - ) + assert_poll_call_functions(output, ["foo"]) - output = self.start(main) - # In this test, the output is deterministic, but it does not follow the - # order in which the coroutines are declared due to interleaving of the - # asyncio event loop. - # - # Note that the order could change between Python versions, so we might - # choose to remove this test, or adapt it in the future. - self.assert_poll_call_functions( - output, - ["d", "h", "e", "f", "g", "a", "b", "c"], - min_results=1, - max_results=8, - ) - def test_resume_after_call(self): - @durable - async def main(): - result1 = await call_one("foo") - result2 = await call_one("bar") - return result1 + result2 +@pytest.mark.asyncio +async def test_call_concurrently_indirect(): + @durable + async def main(*functions): + return await call_concurrently(*functions) - output = self.start(main) - calls = self.assert_poll_call_functions(output, ["foo"]) - output = self.resume( - main, - output, - [CallResult.from_value(1, correlation_id=calls[0].correlation_id)], - ) - calls = self.assert_poll_call_functions(output, ["bar"]) - output = self.resume( - main, - output, - [CallResult.from_value(2, correlation_id=calls[0].correlation_id)], - ) - self.assert_exit_result_value(output, 3) - - def test_resume_after_gather_all_at_once(self): - @durable - async def main(): - return sum(await call_concurrently("a", "b", "c", "d")) - - output = self.start(main) - calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"]) - results = [ - CallResult.from_value(i, correlation_id=call.correlation_id) - for i, call in enumerate(calls) - ] - output = self.resume(main, output, results) - self.assert_exit_result_value(output, 0 + 1 + 2 + 3) - - def test_resume_after_gather_one_at_a_time(self): - @durable - async def main(): - return sum(await call_concurrently("a", "b", "c", "d")) - - output = self.start(main) - calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"]) - for i, call in enumerate(calls): - output = self.resume( - main, - output, - [CallResult.from_value(i, correlation_id=call.correlation_id)], - ) - if i < len(calls) - 1: - self.assert_empty_poll(output) - - self.assert_exit_result_value(output, 0 + 1 + 2 + 3) - - def test_resume_after_any_result(self): - @durable - async def main(): - return await call_any("a", "b", "c", "d") - - output = self.start(main) - calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"]) - - output = self.resume( - main, - output, - [CallResult.from_value(23, correlation_id=calls[1].correlation_id)], - ) - self.assert_exit_result_value(output, 23) - - def test_resume_after_all_errors(self): - @durable - async def main(): - return await call_any("a", "b", "c", "d") - - output = self.start(main) - calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"]) - results = [ - CallResult.from_error( - Error.from_exception(RuntimeError(f"oops{i}")), - correlation_id=call.correlation_id, - ) - for i, call in enumerate(calls) - ] - output = self.resume(main, output, results) - self.assert_exit_result_error( - output, AnyException, "4 coroutine(s) failed with an exception" - ) + output = await start(main, "foo", "bar", "baz") - def test_resume_after_race_result(self): - @durable - async def main(): - return await call_race("a", "b", "c", "d") + assert_poll_call_functions(output, ["foo", "bar", "baz"]) - output = self.start(main) - calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"]) - output = self.resume( - main, - output, - [CallResult.from_value(23, correlation_id=calls[1].correlation_id)], +@pytest.mark.asyncio +async def test_depth_run(): + @durable + async def main(): + return await gather( + call_concurrently("a", "b", "c"), + call_one("d"), + call_concurrently("e", "f", "g"), + call_one("h"), ) - self.assert_exit_result_value(output, 23) - - def test_resume_after_race_error(self): - @durable - async def main(): - return await call_race("a", "b", "c", "d") - - output = self.start(main) - calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"]) - error = Error.from_exception(RuntimeError("oops")) - output = self.resume( - main, - output, - [CallResult.from_error(error, correlation_id=calls[2].correlation_id)], - ) - self.assert_exit_result_error(output, RuntimeError, "oops") - - def test_dag(self): - @durable - async def main(): - result1 = await gather( - call_sequentially("a", "e"), - call_one("b"), - call_concurrently("c", "d"), - ) - result2 = await call_one("f") - result3 = await call_concurrently("g", "h") - return [result1, result2, result3] - - correlation_ids: Set[int] = set() - - output = self.start(main) - # a, b, c, d are called first. e is not because it depends on a. - calls = self.assert_poll_call_functions( - output, ["a", "b", "c", "d"], min_results=1, max_results=4 - ) - correlation_ids.update(call.correlation_id for call in calls) - results = [ - CallResult.from_value(i, correlation_id=call.correlation_id) - for i, call in enumerate(calls) - ] - output = self.resume(main, output, results) - # e is called next - calls = self.assert_poll_call_functions( - output, ["e"], min_results=1, max_results=1 - ) - correlation_ids.update(call.correlation_id for call in calls) - output = self.resume( - main, - output, - [CallResult.from_value(4, correlation_id=calls[0].correlation_id)], - ) - # f is called next - calls = self.assert_poll_call_functions( - output, ["f"], min_results=1, max_results=1 - ) - correlation_ids.update(call.correlation_id for call in calls) - output = self.resume( + output = await start(main) + # In this test, the output is deterministic, but it does not follow the + # order in which the coroutines are declared due to interleaving of the + # asyncio event loop. + # + # Note that the order could change between Python versions, so we might + # choose to remove this test, or adapt it in the future. + assert_poll_call_functions( + output, + ["d", "h", "e", "f", "g", "a", "b", "c"], + min_results=1, + max_results=8, + ) + + +@pytest.mark.asyncio +async def test_resume_after_call(): + @durable + async def main(): + result1 = await call_one("foo") + result2 = await call_one("bar") + return result1 + result2 + + output = await start(main) + calls = assert_poll_call_functions(output, ["foo"]) + output = await resume( + main, + output, + [CallResult.from_value(1, correlation_id=calls[0].correlation_id)], + ) + calls = assert_poll_call_functions(output, ["bar"]) + output = await resume( + main, + output, + [CallResult.from_value(2, correlation_id=calls[0].correlation_id)], + ) + assert_exit_result_value(output, 3) + + +@pytest.mark.asyncio +async def test_resume_after_gather_all_at_once(): + @durable + async def main(): + return sum(await call_concurrently("a", "b", "c", "d")) + + output = await start(main) + calls = assert_poll_call_functions(output, ["a", "b", "c", "d"]) + results = [ + CallResult.from_value(i, correlation_id=call.correlation_id) + for i, call in enumerate(calls) + ] + output = await resume(main, output, results) + assert_exit_result_value(output, 0 + 1 + 2 + 3) + + +@pytest.mark.asyncio +async def test_resume_after_gather_one_at_a_time(): + @durable + async def main(): + return sum(await call_concurrently("a", "b", "c", "d")) + + output = await start(main) + calls = assert_poll_call_functions(output, ["a", "b", "c", "d"]) + for i, call in enumerate(calls): + output = await resume( main, output, - [CallResult.from_value(5, correlation_id=calls[0].correlation_id)], - ) - # g, h are called next - calls = self.assert_poll_call_functions( - output, ["g", "h"], min_results=1, max_results=2 - ) - correlation_ids.update(call.correlation_id for call in calls) - output = self.resume( - main, - output, - [ - CallResult.from_value(6, correlation_id=calls[0].correlation_id), - CallResult.from_value(7, correlation_id=calls[1].correlation_id), - ], - ) - self.assert_exit_result_value( - output, - [ - [[0, 4], 1, [2, 3]], # result1 = (a, e), b, (c, d) - 5, # result2 = f - [6, 7], # result3 = (g, h) - ], + [CallResult.from_value(i, correlation_id=call.correlation_id)], ) + if i < len(calls) - 1: + assert_empty_poll(output) - self.assertEqual(len(correlation_ids), 8) - - def test_poll_error(self): - # The purpose of the test is to ensure that when a poll error occurs, - # we only abort the calls that were made on the previous yield. Any - # other in-flight calls from previous yields are not affected. - - @durable - async def c_then_d(): - c_result = await call_one("c") - try: - # The poll error will affect this call only. - d_result = await call_one("d") - except RuntimeError as e: - assert str(e) == "too many calls" - d_result = 100 - return c_result + d_result - - @durable - async def main(c_then_d): - return await gather( - call_one("a"), - call_one("b"), - c_then_d(), - ) - - output = self.start(main, c_then_d) - calls = self.assert_poll_call_functions( - output, ["a", "b", "c"], min_results=1, max_results=3 - ) - - call_a, call_b, call_c = calls - a_result, b_result, c_result = 10, 20, 30 - output = self.resume( - main, - output, - [CallResult.from_value(c_result, correlation_id=call_c.correlation_id)], - ) - self.assert_poll_call_functions(output, ["d"], min_results=1, max_results=3) - - output = self.resume( - main, output, [], poll_error=RuntimeError("too many calls") - ) - self.assert_poll_call_functions(output, []) - output = self.resume( - main, - output, - [ - CallResult.from_value(a_result, correlation_id=call_a.correlation_id), - CallResult.from_value(b_result, correlation_id=call_b.correlation_id), - ], - ) + assert_exit_result_value(output, 0 + 1 + 2 + 3) - self.assert_exit_result_value(output, [a_result, b_result, c_result + 100]) - def test_raise_indirect(self): - @durable - async def main(): - return await gather(call_one("a"), raises_error()) +@pytest.mark.asyncio +async def test_resume_after_any_result(): + @durable + async def main(): + return await call_any("a", "b", "c", "d") - output = self.start(main) - self.assert_exit_result_error(output, ValueError, "oops") + output = await start(main) + calls = assert_poll_call_functions(output, ["a", "b", "c", "d"]) - def test_raise_reset(self): - @durable - async def main(x: int, y: int): - raise TailCall( - call=Call( - function="main", input=Arguments((), {"x": x + 1, "y": y + 2}) - ) - ) + output = await resume( + main, + output, + [CallResult.from_value(23, correlation_id=calls[1].correlation_id)], + ) + assert_exit_result_value(output, 23) - output = self.start(main, x=1, y=2) - self.assert_exit_tail_call( - output, - tail_call=Call(function="main", input=Arguments((), {"x": 2, "y": 4})), - ) - def test_min_max_results_clamping(self): - @durable - async def main(): - return await call_concurrently("a", "b", "c") +@pytest.mark.asyncio +async def test_resume_after_all_errors(): + @durable + async def main(): + return await call_any("a", "b", "c", "d") - output = self.start(main, poll_min_results=1, poll_max_results=10) - self.assert_poll_call_functions( - output, ["a", "b", "c"], min_results=1, max_results=3 + output = await start(main) + calls = assert_poll_call_functions(output, ["a", "b", "c", "d"]) + results = [ + CallResult.from_error( + Error.from_exception(RuntimeError(f"oops{i}")), + correlation_id=call.correlation_id, ) - - output = self.start(main, poll_min_results=1, poll_max_results=2) - self.assert_poll_call_functions( - output, ["a", "b", "c"], min_results=1, max_results=2 + for i, call in enumerate(calls) + ] + output = await resume(main, output, results) + assert_exit_result_error( + output, AnyException, "4 coroutine(s) failed with an exception" + ) + + +@pytest.mark.asyncio +async def test_resume_after_race_result(): + @durable + async def main(): + return await call_race("a", "b", "c", "d") + + output = await start(main) + calls = assert_poll_call_functions(output, ["a", "b", "c", "d"]) + + output = await resume( + main, + output, + [CallResult.from_value(23, correlation_id=calls[1].correlation_id)], + ) + assert_exit_result_value(output, 23) + + +@pytest.mark.asyncio +async def test_resume_after_race_error(): + @durable + async def main(): + return await call_race("a", "b", "c", "d") + + output = await start(main) + calls = assert_poll_call_functions(output, ["a", "b", "c", "d"]) + + error = Error.from_exception(RuntimeError("oops")) + output = await resume( + main, + output, + [CallResult.from_error(error, correlation_id=calls[2].correlation_id)], + ) + assert_exit_result_error(output, RuntimeError, "oops") + + +@pytest.mark.asyncio +async def test_dag(): + @durable + async def main(): + result1 = await gather( + call_sequentially("a", "e"), + call_one("b"), + call_concurrently("c", "d"), ) - - output = self.start(main, poll_min_results=10, poll_max_results=10) - self.assert_poll_call_functions( - output, ["a", "b", "c"], min_results=3, max_results=3 + result2 = await call_one("f") + result3 = await call_concurrently("g", "h") + return [result1, result2, result3] + + correlation_ids: Set[int] = set() + + output = await start(main) + # a, b, c, d are called first. e is not because it depends on a. + calls = assert_poll_call_functions( + output, ["a", "b", "c", "d"], min_results=1, max_results=4 + ) + correlation_ids.update(call.correlation_id for call in calls) + results = [ + CallResult.from_value(i, correlation_id=call.correlation_id) + for i, call in enumerate(calls) + ] + output = await resume(main, output, results) + # e is called next + calls = assert_poll_call_functions(output, ["e"], min_results=1, max_results=1) + correlation_ids.update(call.correlation_id for call in calls) + output = await resume( + main, + output, + [CallResult.from_value(4, correlation_id=calls[0].correlation_id)], + ) + # f is called next + calls = assert_poll_call_functions(output, ["f"], min_results=1, max_results=1) + correlation_ids.update(call.correlation_id for call in calls) + output = await resume( + main, + output, + [CallResult.from_value(5, correlation_id=calls[0].correlation_id)], + ) + # g, h are called next + calls = assert_poll_call_functions(output, ["g", "h"], min_results=1, max_results=2) + correlation_ids.update(call.correlation_id for call in calls) + output = await resume( + main, + output, + [ + CallResult.from_value(6, correlation_id=calls[0].correlation_id), + CallResult.from_value(7, correlation_id=calls[1].correlation_id), + ], + ) + assert_exit_result_value( + output, + [ + [[0, 4], 1, [2, 3]], # result1 = (a, e), b, (c, d) + 5, # result2 = f + [6, 7], # result3 = (g, h) + ], + ) + + assert len(correlation_ids) == 8 + + +@pytest.mark.asyncio +async def test_poll_error(): + # The purpose of the test is to ensure that when a poll error occurs, + # we only abort the calls that were made on the previous yield. Any + # other in-flight calls from previous yields are not affected. + + @durable + async def c_then_d(): + c_result = await call_one("c") + try: + # The poll error will affect this call only. + d_result = await call_one("d") + except RuntimeError as e: + assert str(e) == "too many calls" + d_result = 100 + return c_result + d_result + + @durable + async def main(c_then_d): + return await gather( + call_one("a"), + call_one("b"), + c_then_d(), ) - def start( - self, - main: Callable, - *args: Any, - poll_min_results=1, - poll_max_results=10, - poll_max_wait_seconds=None, - **kwargs: Any, - ) -> Output: - input = Input.from_input_arguments(main.__qualname__, *args, **kwargs) - return self.runner.run( - OneShotScheduler( - main, - poll_min_results=poll_min_results, - poll_max_results=poll_max_results, - poll_max_wait_seconds=poll_max_wait_seconds, - ).run(input) + output = await start(main, c_then_d) + calls = assert_poll_call_functions( + output, ["a", "b", "c"], min_results=1, max_results=3 + ) + + call_a, call_b, call_c = calls + a_result, b_result, c_result = 10, 20, 30 + output = await resume( + main, + output, + [CallResult.from_value(c_result, correlation_id=call_c.correlation_id)], + ) + assert_poll_call_functions(output, ["d"], min_results=1, max_results=3) + + output = await resume(main, output, [], poll_error=RuntimeError("too many calls")) + assert_poll_call_functions(output, []) + output = await resume( + main, + output, + [ + CallResult.from_value(a_result, correlation_id=call_a.correlation_id), + CallResult.from_value(b_result, correlation_id=call_b.correlation_id), + ], + ) + + assert_exit_result_value(output, [a_result, b_result, c_result + 100]) + + +@pytest.mark.asyncio +async def test_raise_indirect(): + @durable + async def main(): + return await gather(call_one("a"), raises_error()) + + output = await start(main) + assert_exit_result_error(output, ValueError, "oops") + + +@pytest.mark.asyncio +async def test_raise_reset(): + @durable + async def main(x: int, y: int): + raise TailCall( + call=Call(function="main", input=Arguments((), {"x": x + 1, "y": y + 2})) ) + output = await start(main, x=1, y=2) + assert_exit_tail_call( + output, + tail_call=Call(function="main", input=Arguments((), {"x": 2, "y": 4})), + ) + + +@pytest.mark.asyncio +async def test_min_max_results_clamping(): + @durable + async def main(): + return await call_concurrently("a", "b", "c") + + output = await start(main, poll_min_results=1, poll_max_results=10) + assert_poll_call_functions(output, ["a", "b", "c"], min_results=1, max_results=3) + + output = await start(main, poll_min_results=1, poll_max_results=2) + assert_poll_call_functions(output, ["a", "b", "c"], min_results=1, max_results=2) + + output = await start(main, poll_min_results=10, poll_max_results=10) + assert_poll_call_functions(output, ["a", "b", "c"], min_results=3, max_results=3) + + +async def start( + main: Callable, + *args: Any, + poll_min_results=1, + poll_max_results=10, + poll_max_wait_seconds=None, + **kwargs: Any, +) -> Output: + input = Input.from_input_arguments(main.__qualname__, *args, **kwargs) + return await OneShotScheduler( + main, + poll_min_results=poll_min_results, + poll_max_results=poll_max_results, + poll_max_wait_seconds=poll_max_wait_seconds, + ).run(input) + + +async def resume( + main: Callable, + prev_output: Output, + call_results: List[CallResult], + poll_error: Optional[Exception] = None, +): + poll = assert_poll(prev_output) + input = Input.from_poll_results( + main.__qualname__, + poll.coroutine_state, + call_results, + Error.from_exception(poll_error) if poll_error else None, + ) + return await OneShotScheduler(main).run(input) + + +def assert_exit(output: Output) -> exit_pb.Exit: + response = output._message + assert response.HasField("exit") + assert not response.HasField("poll") + return response.exit + + +def assert_exit_result(output: Output) -> call_pb.CallResult: + exit = assert_exit(output) + assert exit.HasField("result") + assert not exit.HasField("tail_call") + return exit.result + + +def assert_exit_result_value(output: Output, expect: Any): + result = assert_exit_result(output) + assert result.HasField("output") + assert not result.HasField("error") + assert expect == any_unpickle(result.output) + + +<<<<<<< HEAD def resume( self, main: Callable, @@ -454,77 +508,60 @@ def resume( Error.from_exception(poll_error) if poll_error else None, ) return self.runner.run(OneShotScheduler(main).run(input)) - - def assert_exit(self, output: Output) -> exit_pb.Exit: - response = output._message - self.assertTrue(response.HasField("exit")) - self.assertFalse(response.HasField("poll")) - return response.exit - - def assert_exit_result(self, output: Output) -> call_pb.CallResult: - exit = self.assert_exit(output) - self.assertTrue(exit.HasField("result")) - self.assertFalse(exit.HasField("tail_call")) - return exit.result - - def assert_exit_result_value(self, output: Output, expect: Any): - result = self.assert_exit_result(output) - self.assertTrue(result.HasField("output")) - self.assertFalse(result.HasField("error")) - self.assertEqual(expect, any_unpickle(result.output)) - - def assert_exit_result_error( - self, output: Output, expect: Type[Exception], message: Optional[str] = None - ): - result = self.assert_exit_result(output) - self.assertFalse(result.HasField("output")) - self.assertTrue(result.HasField("error")) - - error = Error._from_proto(result.error).to_exception() - - self.assertEqual(error.__class__, expect) - if message is not None: - self.assertEqual(str(error), message) - return error - - def assert_exit_tail_call(self, output: Output, tail_call: Call): - exit = self.assert_exit(output) - self.assertFalse(exit.HasField("result")) - self.assertTrue(exit.HasField("tail_call")) - self.assertEqual(tail_call._as_proto(), exit.tail_call) - - def assert_poll(self, output: Output) -> poll_pb.Poll: - response = output._message - if response.HasField("exit"): - raise RuntimeError( - f"coroutine unexpectedly returned {response.exit.result}" - ) - self.assertTrue(response.HasField("poll")) - return response.poll - - def assert_empty_poll(self, output: Output): - poll = self.assert_poll(output) - self.assertEqual(len(poll.calls), 0) - - def assert_poll_call_functions( - self, output: Output, expect: List[str], min_results=None, max_results=None - ): - poll = self.assert_poll(output) - # Note: we're not testing endpoint/input here. - # Check function names match: - self.assertListEqual([c.function for c in poll.calls], expect) - # Check correlation IDs are unique. - correlation_ids = [c.correlation_id for c in poll.calls] - self.assertEqual( - len(correlation_ids), - len(set(correlation_ids)), - "correlation IDs were not unique", - ) - if min_results is not None: - self.assertEqual(min_results, poll.min_results) - if max_results is not None: - self.assertEqual(max_results, poll.max_results) - return poll.calls +======= +def assert_exit_result_error( + output: Output, expect: Type[Exception], message: Optional[str] = None +): + result = assert_exit_result(output) + assert not result.HasField("output") + assert result.HasField("error") +>>>>>>> 626d02d (aiohttp: refactor internals to use asyncio throughout the SDK) + + error = Error._from_proto(result.error).to_exception() + assert error.__class__ == expect + + if message is not None: + assert str(error) == message + return error + + +def assert_exit_tail_call(output: Output, tail_call: Call): + exit = assert_exit(output) + assert not exit.HasField("result") + assert exit.HasField("tail_call") + assert tail_call._as_proto() == exit.tail_call + + +def assert_poll(output: Output) -> poll_pb.Poll: + response = output._message + if response.HasField("exit"): + raise RuntimeError(f"coroutine unexpectedly returned {response.exit.result}") + assert response.HasField("poll") + return response.poll + + +def assert_empty_poll(output: Output): + poll = assert_poll(output) + assert len(poll.calls) == 0 + + +def assert_poll_call_functions( + output: Output, expect: List[str], min_results=None, max_results=None +): + poll = assert_poll(output) + # Note: we're not testing endpoint/input here. + # Check function names match: + assert [c.function for c in poll.calls] == expect + # Check correlation IDs are unique. + correlation_ids = [c.correlation_id for c in poll.calls] + assert len(correlation_ids) == len( + set(correlation_ids) + ), "correlation IDs were not unique" + if min_results is not None: + assert min_results == poll.min_results + if max_results is not None: + assert max_results == poll.max_results + return poll.calls class TestAllFuture(unittest.TestCase): diff --git a/tests/dispatch/test_status.py b/tests/dispatch/test_status.py index ec799ca5..d42b20d1 100644 --- a/tests/dispatch/test_status.py +++ b/tests/dispatch/test_status.py @@ -1,4 +1,3 @@ -import unittest from typing import Any from dispatch import error @@ -12,232 +11,261 @@ ) -class TestErrorStatus(unittest.TestCase): - def test_status_for_Exception(self): - assert status_for_error(Exception()) is Status.PERMANENT_ERROR +def test_status_for_Exception(): + assert status_for_error(Exception()) is Status.PERMANENT_ERROR - def test_status_for_ValueError(self): - assert status_for_error(ValueError()) is Status.INVALID_ARGUMENT - def test_status_for_TypeError(self): - assert status_for_error(TypeError()) is Status.INVALID_ARGUMENT +def test_status_for_ValueError(): + assert status_for_error(ValueError()) is Status.INVALID_ARGUMENT - def test_status_for_KeyError(self): - assert status_for_error(KeyError()) is Status.PERMANENT_ERROR - def test_status_for_EOFError(self): - assert status_for_error(EOFError()) is Status.TEMPORARY_ERROR +def test_status_for_TypeError(): + assert status_for_error(TypeError()) is Status.INVALID_ARGUMENT - def test_status_for_ConnectionError(self): - assert status_for_error(ConnectionError()) is Status.TCP_ERROR - def test_status_for_PermissionError(self): - assert status_for_error(PermissionError()) is Status.PERMISSION_DENIED +def test_status_for_KeyError(): + assert status_for_error(KeyError()) is Status.PERMANENT_ERROR - def test_status_for_FileNotFoundError(self): - assert status_for_error(FileNotFoundError()) is Status.NOT_FOUND - def test_status_for_InterruptedError(self): - assert status_for_error(InterruptedError()) is Status.TEMPORARY_ERROR +def test_status_for_EOFError(): + assert status_for_error(EOFError()) is Status.TEMPORARY_ERROR - def test_status_for_KeyboardInterrupt(self): - assert status_for_error(KeyboardInterrupt()) is Status.TEMPORARY_ERROR - def test_status_for_OSError(self): - assert status_for_error(OSError()) is Status.TEMPORARY_ERROR +def test_status_for_ConnectionError(): + assert status_for_error(ConnectionError()) is Status.TCP_ERROR - def test_status_for_TimeoutError(self): - assert status_for_error(TimeoutError()) is Status.TIMEOUT - def test_status_for_BaseException(self): - assert status_for_error(BaseException()) is Status.PERMANENT_ERROR +def test_status_for_PermissionError(): + assert status_for_error(PermissionError()) is Status.PERMISSION_DENIED - def test_status_for_custom_error(self): - class CustomError(Exception): - pass - assert status_for_error(CustomError()) is Status.PERMANENT_ERROR +def test_status_for_FileNotFoundError(): + assert status_for_error(FileNotFoundError()) is Status.NOT_FOUND - def test_status_for_custom_error_with_handler(self): - class CustomError(Exception): - pass - def handler(error: Exception) -> Status: - assert isinstance(error, CustomError) - return Status.OK +def test_status_for_InterruptedError(): + assert status_for_error(InterruptedError()) is Status.TEMPORARY_ERROR - register_error_type(CustomError, handler) - assert status_for_error(CustomError()) is Status.OK - def test_status_for_custom_error_with_base_handler(self): - class CustomBaseError(Exception): - pass +def test_status_for_KeyboardInterrupt(): + assert status_for_error(KeyboardInterrupt()) is Status.TEMPORARY_ERROR - class CustomError(CustomBaseError): - pass - def handler(error: Exception) -> Status: - assert isinstance(error, CustomBaseError) - assert isinstance(error, CustomError) - return Status.TCP_ERROR +def test_status_for_OSError(): + assert status_for_error(OSError()) is Status.TEMPORARY_ERROR - register_error_type(CustomBaseError, handler) - assert status_for_error(CustomError()) is Status.TCP_ERROR - def test_status_for_custom_error_with_status(self): - class CustomError(Exception): - pass +def test_status_for_TimeoutError(): + assert status_for_error(TimeoutError()) is Status.TIMEOUT - register_error_type(CustomError, Status.THROTTLED) - assert status_for_error(CustomError()) is Status.THROTTLED - def test_status_for_custom_error_with_base_status(self): - class CustomBaseError(Exception): - pass +def test_status_for_BaseException(): + assert status_for_error(BaseException()) is Status.PERMANENT_ERROR - class CustomError(CustomBaseError): - pass - class CustomError2(CustomBaseError): - pass +def test_status_for_custom_error(): + class CustomError(Exception): + pass - register_error_type(CustomBaseError, Status.THROTTLED) - register_error_type(CustomError2, Status.INVALID_ARGUMENT) - assert status_for_error(CustomError()) is Status.THROTTLED - assert status_for_error(CustomError2()) is Status.INVALID_ARGUMENT + assert status_for_error(CustomError()) is Status.PERMANENT_ERROR - def test_status_for_custom_timeout(self): - class CustomError(TimeoutError): - pass - assert status_for_error(CustomError()) is Status.TIMEOUT +def test_status_for_custom_error_with_handler(): + class CustomError(Exception): + pass - def test_status_for_DispatchError(self): - assert status_for_error(error.TimeoutError()) is Status.TIMEOUT - assert status_for_error(error.ThrottleError()) is Status.THROTTLED - assert status_for_error(error.InvalidArgumentError()) is Status.INVALID_ARGUMENT - assert status_for_error(error.InvalidResponseError()) is Status.INVALID_RESPONSE - assert status_for_error(error.TemporaryError()) is Status.TEMPORARY_ERROR - assert status_for_error(error.PermanentError()) is Status.PERMANENT_ERROR - assert ( - status_for_error(error.IncompatibleStateError()) - is Status.INCOMPATIBLE_STATE - ) - assert status_for_error(error.DNSError()) is Status.DNS_ERROR - assert status_for_error(error.TCPError()) is Status.TCP_ERROR - assert status_for_error(error.HTTPError()) is Status.HTTP_ERROR - assert status_for_error(error.UnauthenticatedError()) is Status.UNAUTHENTICATED - assert ( - status_for_error(error.PermissionDeniedError()) is Status.PERMISSION_DENIED - ) - assert status_for_error(error.NotFoundError()) is Status.NOT_FOUND - assert status_for_error(error.DispatchError()) is Status.UNSPECIFIED + def handler(error: Exception) -> Status: + assert isinstance(error, CustomError) + return Status.OK - def test_status_for_custom_output(self): - class CustomOutput: - pass + register_error_type(CustomError, handler) + assert status_for_error(CustomError()) is Status.OK - assert status_for_output(CustomOutput()) is Status.OK # default - def test_status_for_custom_output_with_handler(self): - class CustomOutput: - pass +def test_status_for_custom_error_with_base_handler(): + class CustomBaseError(Exception): + pass - def handler(output: Any) -> Status: - assert isinstance(output, CustomOutput) + class CustomError(CustomBaseError): + pass + + def handler(error: Exception) -> Status: + assert isinstance(error, CustomBaseError) + assert isinstance(error, CustomError) + return Status.TCP_ERROR + + register_error_type(CustomBaseError, handler) + assert status_for_error(CustomError()) is Status.TCP_ERROR + + +def test_status_for_custom_error_with_status(): + class CustomError(Exception): + pass + + register_error_type(CustomError, Status.THROTTLED) + assert status_for_error(CustomError()) is Status.THROTTLED + + +def test_status_for_custom_error_with_base_status(): + class CustomBaseError(Exception): + pass + + class CustomError(CustomBaseError): + pass + + class CustomError2(CustomBaseError): + pass + + register_error_type(CustomBaseError, Status.THROTTLED) + register_error_type(CustomError2, Status.INVALID_ARGUMENT) + assert status_for_error(CustomError()) is Status.THROTTLED + assert status_for_error(CustomError2()) is Status.INVALID_ARGUMENT + + +def test_status_for_custom_timeout(): + class CustomError(TimeoutError): + pass + + assert status_for_error(CustomError()) is Status.TIMEOUT + + +def test_status_for_DispatchError(): + assert status_for_error(error.TimeoutError()) is Status.TIMEOUT + assert status_for_error(error.ThrottleError()) is Status.THROTTLED + assert status_for_error(error.InvalidArgumentError()) is Status.INVALID_ARGUMENT + assert status_for_error(error.InvalidResponseError()) is Status.INVALID_RESPONSE + assert status_for_error(error.TemporaryError()) is Status.TEMPORARY_ERROR + assert status_for_error(error.PermanentError()) is Status.PERMANENT_ERROR + assert status_for_error(error.IncompatibleStateError()) is Status.INCOMPATIBLE_STATE + assert status_for_error(error.DNSError()) is Status.DNS_ERROR + assert status_for_error(error.TCPError()) is Status.TCP_ERROR + assert status_for_error(error.HTTPError()) is Status.HTTP_ERROR + assert status_for_error(error.UnauthenticatedError()) is Status.UNAUTHENTICATED + assert status_for_error(error.PermissionDeniedError()) is Status.PERMISSION_DENIED + assert status_for_error(error.NotFoundError()) is Status.NOT_FOUND + assert status_for_error(error.DispatchError()) is Status.UNSPECIFIED + + +def test_status_for_custom_output(): + class CustomOutput: + pass + + assert status_for_output(CustomOutput()) is Status.OK # default + + +def test_status_for_custom_output_with_handler(): + class CustomOutput: + pass + + def handler(output: Any) -> Status: + assert isinstance(output, CustomOutput) + return Status.DNS_ERROR + + register_output_type(CustomOutput, handler) + assert status_for_output(CustomOutput()) is Status.DNS_ERROR + + +def test_status_for_custom_output_with_base_handler(): + class CustomOutputBase: + pass + + class CustomOutputError(CustomOutputBase): + pass + + class CustomOutputSuccess(CustomOutputBase): + pass + + def handler(output: Any) -> Status: + assert isinstance(output, CustomOutputBase) + if isinstance(output, CustomOutputError): return Status.DNS_ERROR + assert isinstance(output, CustomOutputSuccess) + return Status.OK + + register_output_type(CustomOutputBase, handler) + assert status_for_output(CustomOutputSuccess()) is Status.OK + assert status_for_output(CustomOutputError()) is Status.DNS_ERROR + + +def test_status_for_custom_output_with_status(): + class CustomOutputBase: + pass + + class CustomOutputChild1(CustomOutputBase): + pass + + class CustomOutputChild2(CustomOutputBase): + pass + + register_output_type(CustomOutputBase, Status.PERMISSION_DENIED) + register_output_type(CustomOutputChild1, Status.TCP_ERROR) + assert status_for_output(CustomOutputChild1()) is Status.TCP_ERROR + assert status_for_output(CustomOutputChild2()) is Status.PERMISSION_DENIED + + +def test_status_for_custom_output_with_base_status(): + class CustomOutput(Exception): + pass + + register_output_type(CustomOutput, Status.THROTTLED) + assert status_for_output(CustomOutput()) is Status.THROTTLED - register_output_type(CustomOutput, handler) - assert status_for_output(CustomOutput()) is Status.DNS_ERROR - def test_status_for_custom_output_with_base_handler(self): - class CustomOutputBase: - pass +def test_http_response_code_status_400(): + assert http_response_code_status(400) is Status.INVALID_ARGUMENT - class CustomOutputError(CustomOutputBase): - pass - class CustomOutputSuccess(CustomOutputBase): - pass +def test_http_response_code_status_401(): + assert http_response_code_status(401) is Status.UNAUTHENTICATED - def handler(output: Any) -> Status: - assert isinstance(output, CustomOutputBase) - if isinstance(output, CustomOutputError): - return Status.DNS_ERROR - assert isinstance(output, CustomOutputSuccess) - return Status.OK - register_output_type(CustomOutputBase, handler) - assert status_for_output(CustomOutputSuccess()) is Status.OK - assert status_for_output(CustomOutputError()) is Status.DNS_ERROR +def test_http_response_code_status_403(): + assert http_response_code_status(403) is Status.PERMISSION_DENIED - def test_status_for_custom_output_with_status(self): - class CustomOutputBase: - pass - class CustomOutputChild1(CustomOutputBase): - pass +def test_http_response_code_status_404(): + assert http_response_code_status(404) is Status.NOT_FOUND - class CustomOutputChild2(CustomOutputBase): - pass - register_output_type(CustomOutputBase, Status.PERMISSION_DENIED) - register_output_type(CustomOutputChild1, Status.TCP_ERROR) - assert status_for_output(CustomOutputChild1()) is Status.TCP_ERROR - assert status_for_output(CustomOutputChild2()) is Status.PERMISSION_DENIED +def test_http_response_code_status_408(): + assert http_response_code_status(408) is Status.TIMEOUT - def test_status_for_custom_output_with_base_status(self): - class CustomOutput(Exception): - pass - register_output_type(CustomOutput, Status.THROTTLED) - assert status_for_output(CustomOutput()) is Status.THROTTLED +def test_http_response_code_status_429(): + assert http_response_code_status(429) is Status.THROTTLED -class TestHTTPStatusCodes(unittest.TestCase): - def test_http_response_code_status_400(self): - assert http_response_code_status(400) is Status.INVALID_ARGUMENT +def test_http_response_code_status_501(): + assert http_response_code_status(501) is Status.PERMANENT_ERROR - def test_http_response_code_status_401(self): - assert http_response_code_status(401) is Status.UNAUTHENTICATED - def test_http_response_code_status_403(self): - assert http_response_code_status(403) is Status.PERMISSION_DENIED +def test_http_response_code_status_1xx(): + for status in range(100, 200): + assert http_response_code_status(100) is Status.PERMANENT_ERROR - def test_http_response_code_status_404(self): - assert http_response_code_status(404) is Status.NOT_FOUND - def test_http_response_code_status_408(self): - assert http_response_code_status(408) is Status.TIMEOUT +def test_http_response_code_status_2xx(): + for status in range(200, 300): + assert http_response_code_status(200) is Status.OK - def test_http_response_code_status_429(self): - assert http_response_code_status(429) is Status.THROTTLED - def test_http_response_code_status_501(self): - assert http_response_code_status(501) is Status.PERMANENT_ERROR +def test_http_response_code_status_3xx(): + for status in range(300, 400): + assert http_response_code_status(300) is Status.PERMANENT_ERROR - def test_http_response_code_status_1xx(self): - for status in range(100, 200): - assert http_response_code_status(100) is Status.PERMANENT_ERROR - def test_http_response_code_status_2xx(self): - for status in range(200, 300): - assert http_response_code_status(200) is Status.OK +def test_http_response_code_status_4xx(): + for status in range(400, 500): + if status not in (400, 401, 403, 404, 408, 429, 501): + assert http_response_code_status(status) is Status.PERMANENT_ERROR - def test_http_response_code_status_3xx(self): - for status in range(300, 400): - assert http_response_code_status(300) is Status.PERMANENT_ERROR - def test_http_response_code_status_4xx(self): - for status in range(400, 500): - if status not in (400, 401, 403, 404, 408, 429, 501): - assert http_response_code_status(status) is Status.PERMANENT_ERROR +def test_http_response_code_status_5xx(): + for status in range(500, 600): + if status not in (501,): + assert http_response_code_status(status) is Status.TEMPORARY_ERROR - def test_http_response_code_status_5xx(self): - for status in range(500, 600): - if status not in (501,): - assert http_response_code_status(status) is Status.TEMPORARY_ERROR - def test_http_response_code_status_6xx(self): - for status in range(600, 700): - assert http_response_code_status(600) is Status.UNSPECIFIED +def test_http_response_code_status_6xx(): + for status in range(600, 700): + assert http_response_code_status(600) is Status.UNSPECIFIED diff --git a/tests/test_client.py b/tests/test_client.py index c04945b2..70e754d2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,104 +1,59 @@ import os -import unittest from unittest import mock -import httpx +import pytest -import dispatch.test.httpx -from dispatch import Call, Client -from dispatch.proto import _any_unpickle as any_unpickle -from dispatch.test import DispatchServer, DispatchService, EndpointClient +import dispatch.test +from dispatch import Call +from dispatch.test import Client -class TestClient(unittest.TestCase): - def setUp(self): - http_client = dispatch.test.httpx.Client( - httpx.Client(base_url="http://function-service") - ) - endpoint_client = EndpointClient(http_client) +def server() -> dispatch.test.Server: + return dispatch.test.Server(dispatch.test.Service()) - api_key = "0000000000000000" - self.dispatch_service = DispatchService(endpoint_client, api_key) - self.dispatch_server = DispatchServer(self.dispatch_service) - self.dispatch_client = Client(api_key, api_url=self.dispatch_server.url) - self.dispatch_server.start() +@mock.patch.dict(os.environ, {"DISPATCH_API_KEY": ""}) +def test_api_key_missing(): + with pytest.raises(ValueError) as mc: + Client() + assert ( + str(mc.value) + == "missing API key: set it with the DISPATCH_API_KEY environment variable" + ) - def tearDown(self): - self.dispatch_server.stop() - @mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "WHATEVER"}) - def test_api_key_from_env(self): - client = Client(api_url=self.dispatch_server.url) +def test_url_bad_scheme(): + with pytest.raises(ValueError) as mc: + Client(api_url="ftp://example.com", api_key="foo") + assert str(mc.value) == "Invalid API scheme: 'ftp'" - with self.assertRaisesRegex( + +def test_can_be_constructed_on_https(): + # Goal is to not raise an exception here. We don't have an HTTPS server + # around to actually test this. + Client(api_url="https://example.com", api_key="foo") + + +@mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "0000000000000000"}) +@pytest.mark.asyncio +async def test_api_key_from_env(): + async with server() as api: + client = Client(api_url=api.url) + + with pytest.raises( PermissionError, - r"Dispatch received an invalid authentication token \(check DISPATCH_API_KEY is correct\)", + match=r"Dispatch received an invalid authentication token \(check DISPATCH_API_KEY is correct\)", ) as mc: - client.dispatch([Call(function="my-function", input=42)]) + await client.dispatch([Call(function="my-function", input=42)]) + - def test_api_key_from_arg(self): - client = Client(api_url=self.dispatch_server.url, api_key="WHATEVER") +@pytest.mark.asyncio +async def test_api_key_from_arg(): + async with server() as api: + client = Client(api_url=api.url, api_key="WHATEVER") - with self.assertRaisesRegex( + with pytest.raises( PermissionError, - r"Dispatch received an invalid authentication token \(check api_key is correct\)", + match=r"Dispatch received an invalid authentication token \(check api_key is correct\)", ) as mc: - client.dispatch([Call(function="my-function", input=42)]) - - @mock.patch.dict(os.environ, {"DISPATCH_API_KEY": ""}) - def test_api_key_missing(self): - with self.assertRaises(ValueError) as mc: - Client() - self.assertEqual( - str(mc.exception), - "missing API key: set it with the DISPATCH_API_KEY environment variable", - ) - - def test_url_bad_scheme(self): - with self.assertRaises(ValueError) as mc: - Client(api_url="ftp://example.com", api_key="foo") - self.assertEqual(str(mc.exception), "Invalid API scheme: 'ftp'") - - def test_can_be_constructed_on_https(self): - # Goal is to not raise an exception here. We don't have an HTTPS server - # around to actually test this. - Client(api_url="https://example.com", api_key="foo") - - def test_call_pickle(self): - dispatch_ids = self.dispatch_client.dispatch( - [Call(function="my-function", input=42)] - ) - self.assertEqual(len(dispatch_ids), 1) - - pending_calls = self.dispatch_service.queue - self.assertEqual(len(pending_calls), 1) - dispatch_id, call, _ = pending_calls[0] - self.assertEqual(dispatch_id, dispatch_ids[0]) - self.assertEqual(call.function, "my-function") - self.assertEqual(any_unpickle(call.input), 42) - - def test_batch(self): - batch = self.dispatch_client.batch() - - batch.add_call(Call(function="my-function", input=42)).add_call( - Call(function="my-function", input=23) - ).add_call(Call(function="my-function2", input=11)) - - dispatch_ids = batch.dispatch() - self.assertEqual(len(dispatch_ids), 3) - - pending_calls = self.dispatch_service.queue - self.assertEqual(len(pending_calls), 3) - dispatch_id0, call0, _ = pending_calls[0] - dispatch_id1, call1, _ = pending_calls[1] - dispatch_id2, call2, _ = pending_calls[2] - self.assertListEqual(dispatch_ids, [dispatch_id0, dispatch_id1, dispatch_id2]) - self.assertEqual(call0.function, "my-function") - self.assertEqual(any_unpickle(call0.input), 42) - self.assertEqual(call1.function, "my-function") - self.assertEqual(any_unpickle(call1.input), 23) - self.assertEqual(call2.function, "my-function2") - self.assertEqual(any_unpickle(call2.input), 11) - - self.assertEqual(len(batch.dispatch()), 0) # batch was reset + await client.dispatch([Call(function="my-function", input=42)]) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index e7f4e528..68e22b56 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,3 +1,4 @@ +import asyncio import base64 import os import pickle @@ -30,7 +31,7 @@ public_key_from_pem, ) from dispatch.status import Status -from dispatch.test import DispatchServer, DispatchService, EndpointClient +from dispatch.test import Client, DispatchServer, DispatchService, EndpointClient from dispatch.test.fastapi import http_client @@ -134,9 +135,7 @@ def setUp(self): endpoint_client, api_key, collect_roundtrips=True ) self.dispatch_server = DispatchServer(self.dispatch_service) - self.dispatch_client = dispatch.Client( - api_key, api_url=self.dispatch_server.url - ) + self.dispatch_client = Client(api_key, api_url=self.dispatch_server.url) self.dispatch = Dispatch( self.endpoint_app, @@ -157,11 +156,13 @@ def test_simple_end_to_end(self): def my_function(name: str) -> str: return f"Hello world: {name}" - call = my_function.build_call(52) + call = my_function.build_call("52") self.assertEqual(call.function.split(".")[-1], "my_function") # The client. - [dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)]) + [dispatch_id] = asyncio.run( + self.dispatch_client.dispatch([my_function.build_call("52")]) + ) # Simulate execution for testing purposes. self.dispatch_service.dispatch_calls() @@ -177,10 +178,10 @@ def test_simple_missing_signature(self): async def my_function(name: str) -> str: return f"Hello world: {name}" - call = my_function.build_call(52) + call = my_function.build_call("52") self.assertEqual(call.function.split(".")[-1], "my_function") - [dispatch_id] = self.dispatch_client.dispatch([call]) + [dispatch_id] = asyncio.run(self.dispatch_client.dispatch([call])) self.dispatch_service.endpoint_client = create_endpoint_client( self.endpoint_app @@ -548,17 +549,3 @@ def get(path: str) -> httpx.Response: http_response = any_unpickle(resp.exit.result.output) self.assertEqual("application/json", http_response.headers["content-type"]) self.assertEqual('"OK"', http_response.text) - - -class TestError(unittest.TestCase): - def test_error_with_ok_status(self): - with self.assertRaises(ValueError): - Error(Status.OK, type="type", message="yep") - - def test_from_exception_timeout(self): - err = Error.from_exception(TimeoutError()) - self.assertEqual(Status.TIMEOUT, err.status) - - def test_from_exception_syntax_error(self): - err = Error.from_exception(SyntaxError()) - self.assertEqual(Status.PERMANENT_ERROR, err.status) diff --git a/tests/test_flask.py b/tests/test_flask.py index ae6d4312..b78c56be 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -1,3 +1,4 @@ +import asyncio import base64 import os import pickle @@ -28,7 +29,7 @@ public_key_from_pem, ) from dispatch.status import Status -from dispatch.test import DispatchServer, DispatchService, EndpointClient +from dispatch.test import Client, DispatchServer, DispatchService, EndpointClient from dispatch.test.flask import http_client @@ -98,9 +99,7 @@ def setUp(self): endpoint_client, api_key, collect_roundtrips=True ) self.dispatch_server = DispatchServer(self.dispatch_service) - self.dispatch_client = dispatch.Client( - api_key, api_url=self.dispatch_server.url - ) + self.dispatch_client = Client(api_key, api_url=self.dispatch_server.url) self.dispatch = Dispatch( self.endpoint_app, @@ -121,11 +120,13 @@ def test_simple_end_to_end(self): def my_function(name: str) -> str: return f"Hello world: {name}" - call = my_function.build_call(52) + call = my_function.build_call("52") self.assertEqual(call.function.split(".")[-1], "my_function") # The client. - [dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)]) + [dispatch_id] = asyncio.run( + self.dispatch_client.dispatch([my_function.build_call("52")]) + ) # Simulate execution for testing purposes. self.dispatch_service.dispatch_calls() diff --git a/tests/test_http.py b/tests/test_http.py index 56aa8e38..34d27ab4 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,105 +1,57 @@ -import base64 -import os -import pickle -import struct -import threading -import unittest +import asyncio +import socket from http.server import HTTPServer -from typing import Any, Tuple -from unittest import mock -import fastapi -import google.protobuf.any_pb2 -import google.protobuf.wrappers_pb2 -import httpx -from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey +import dispatch.test +from dispatch.asyncio import Runner +from dispatch.function import Registry +from dispatch.http import Dispatch, FunctionService, Server -import dispatch.test.httpx -from dispatch.experimental.durable.registry import clear_functions -from dispatch.function import Arguments, Error, Function, Input, Output, Registry -from dispatch.http import Dispatch -from dispatch.proto import _any_unpickle as any_unpickle -from dispatch.proto import _pb_any_pickle as any_pickle -from dispatch.sdk.v1 import call_pb2 as call_pb -from dispatch.sdk.v1 import function_pb2 as function_pb -from dispatch.signature import parse_verification_key, public_key_from_pem -from dispatch.status import Status -from dispatch.test import EndpointClient -public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----" -public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----" -public_key = public_key_from_pem(public_key_pem) -public_key_bytes = public_key.public_bytes_raw() -public_key_b64 = base64.b64encode(public_key_bytes) +class TestHTTP(dispatch.test.TestCase): -from datetime import datetime + def dispatch_test_init(self, reg: Registry) -> str: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + sock.listen(128) + (host, port) = sock.getsockname() -class TestHTTP(unittest.TestCase): - def setUp(self): - host = "127.0.0.1" - port = 9999 - - self.server_address = (host, port) - self.endpoint = f"http://{host}:{port}" - self.dispatch = Dispatch( - Registry( - endpoint=self.endpoint, - api_key="0000000000000000", - api_url="http://127.0.0.1:10000", - ), + self.httpserver = HTTPServer( + server_address=(host, port), + RequestHandlerClass=Dispatch(reg), + bind_and_activate=False, ) + self.httpserver.socket = sock + return f"http://{host}:{port}" - self.client = httpx.Client(timeout=1.0) - self.server = HTTPServer(self.server_address, self.dispatch) - self.thread = threading.Thread( - target=lambda: self.server.serve_forever(poll_interval=0.05) - ) - self.thread.start() - - def tearDown(self): - self.server.shutdown() - self.thread.join(timeout=1.0) - self.client.close() - self.server.server_close() - - def test_content_length_missing(self): - resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run") - body = resp.read() - self.assertEqual(resp.status_code, 400) - self.assertEqual( - body, b'{"code":"invalid_argument","message":"content length is required"}' - ) + def dispatch_test_run(self): + self.httpserver.serve_forever(poll_interval=0.05) - def test_content_length_too_large(self): - resp = self.client.post( - f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run", - data={"msg": "a" * 16_000_001}, - ) - body = resp.read() - self.assertEqual(resp.status_code, 400) - self.assertEqual( - body, b'{"code":"invalid_argument","message":"content length is too large"}' - ) + def dispatch_test_stop(self): + self.httpserver.shutdown() + self.httpserver.server_close() + self.httpserver.socket.close() - def test_simple_request(self): - @self.dispatch.registry.primitive_function - async def my_function(input: Input) -> Output: - return Output.value( - f"You told me: '{input.input}' ({len(input.input)} characters)" - ) - http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) - client = EndpointClient(http_client) +class TestAIOHTTP(dispatch.test.TestCase): - req = function_pb.RunRequest( - function=my_function.name, input=any_pickle("Hello World!") - ) + def dispatch_test_init(self, reg: Registry) -> str: + host = "127.0.0.1" + port = 0 - resp = client.run(req) + self.aiowait = asyncio.Event() + self.aioloop = Runner() + self.aiohttp = Server(host, port, Dispatch(reg)) + self.aioloop.run(self.aiohttp.start()) - self.assertIsInstance(resp, function_pb.RunResponse) + return f"http://{self.aiohttp.host}:{self.aiohttp.port}" - output = any_unpickle(resp.exit.result.output) + def dispatch_test_run(self): + self.aioloop.run(self.aiowait.wait()) + self.aioloop.run(self.aiohttp.stop()) + self.aioloop.close() - self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") + def dispatch_test_stop(self): + self.aioloop.get_loop().call_soon_threadsafe(self.aiowait.set) From 935ffdcf43b4083187900bf0bf42ba3cebf336c2 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 11 Jun 2024 16:19:02 -0700 Subject: [PATCH 02/34] all tests pass using new APIs Signed-off-by: Achille Roussel --- src/dispatch/flask.py | 24 +- src/dispatch/function.py | 10 +- src/dispatch/test/__init__.py | 164 ++++-- tests/test_fastapi.py | 919 +++++++++++++++------------------- tests/test_flask.py | 142 +----- 5 files changed, 565 insertions(+), 694 deletions(-) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index 5f48dd1e..b65270a4 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -17,12 +17,12 @@ def read_root(): my_function.dispatch() """ +import asyncio import logging from typing import Optional, Union from flask import Flask, make_response, request -from dispatch.asyncio import Runner from dispatch.function import Registry from dispatch.http import FunctionServiceError, function_service_run from dispatch.signature import Ed25519PublicKey, parse_verification_key @@ -81,7 +81,6 @@ def __init__( ) app.errorhandler(FunctionServiceError)(self._handle_error) - app.post("/dispatch.sdk.v1.FunctionService/Run")(self._execute) def _handle_error(self, exc: FunctionServiceError): @@ -90,17 +89,16 @@ def _handle_error(self, exc: FunctionServiceError): def _execute(self): data: bytes = request.get_data(cache=False) - with Runner() as runner: - content = runner.run( - function_service_run( - request.url, - request.method, - dict(request.headers), - data, - self, - self._verification_key, - ), - ) + content = asyncio.run( + function_service_run( + request.url, + request.method, + dict(request.headers), + data, + self, + self._verification_key, + ), + ) res = make_response(content) res.content_type = "application/proto" diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 75302c78..f5807a15 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -488,9 +488,13 @@ def __init__( @classmethod def from_response(cls, status: int, body: bytes) -> ClientError: - error_dict = json.loads(body) - error_code = str(error_dict.get("code")) or "unknown" - error_message = str(error_dict.get("message")) or "unknown" + try: + error_dict = json.loads(body) + error_code = str(error_dict.get("code")) or "unknown" + error_message = str(error_dict.get("message")) or "unknown" + except json.JSONDecodeError: + error_code = "unknown" + error_message = str(body) return cls(status, error_code, error_message) diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index f8179324..0f07e521 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -1,4 +1,5 @@ import asyncio +import threading import unittest from datetime import datetime, timedelta from functools import wraps @@ -7,9 +8,11 @@ import aiohttp from aiohttp import web from google.protobuf.timestamp_pb2 import Timestamp +from typing_extensions import ParamSpec import dispatch.experimental.durable.registry from dispatch.function import Client as BaseClient +from dispatch.function import ClientError from dispatch.function import Registry as BaseRegistry from dispatch.http import Dispatch from dispatch.http import Server as BaseServer @@ -41,6 +44,7 @@ "DISPATCH_API_KEY", ] +P = ParamSpec("P") R = TypeVar("R", bound=BaseRegistry) T = TypeVar("T") @@ -79,8 +83,9 @@ def url(self): class Service(web.Application): tasks: dict[str, asyncio.Task[CallResult]] + _session: Optional[aiohttp.ClientSession] = None - def __init__(self): + def __init__(self, session: Optional[aiohttp.ClientSession] = None): super().__init__() self.dispatch_ids = (str(i) for i in range(2**32 - 1)) self.tasks = {} @@ -97,6 +102,9 @@ def __init__(self): ] ) + async def close(self): + await self.session.close() + async def authenticate(self, request: web.Request): auth = request.headers.get("Authorization") if not auth or not auth.startswith("Bearer "): @@ -109,8 +117,7 @@ async def authenticate(self, request: web.Request): async def handle_dispatch_request(self, request: web.Request): await self.authenticate(request) req = DispatchRequest.FromString(await request.read()) - async with aiohttp.ClientSession() as session: - res = await self.dispatch(session, req) + res = await self.dispatch(req) return web.Response( content_type="application/proto", body=res.SerializeToString() ) @@ -123,25 +130,24 @@ async def handle_wait_request(self, request: web.Request): content_type="application/proto", body=res.SerializeToString() ) - async def dispatch( - self, session: aiohttp.ClientSession, req: DispatchRequest - ) -> DispatchResponse: + async def dispatch(self, req: DispatchRequest) -> DispatchResponse: dispatch_ids = [next(self.dispatch_ids) for _ in req.calls] for call, dispatch_id in zip(req.calls, dispatch_ids): self.tasks[dispatch_id] = asyncio.create_task( - self.call(session, call, dispatch_id) + self.call(call, dispatch_id), + name=f"dispatch:{dispatch_id}", ) return DispatchResponse(dispatch_ids=dispatch_ids) # TODO: add to protobuf definitions async def wait(self, dispatch_id: str) -> CallResult: - return await self.tasks[dispatch_id] + task = self.tasks[dispatch_id] + return await task async def call( self, - session: aiohttp.ClientSession, call: Call, dispatch_id: str, parent_dispatch_id: Optional[str] = None, @@ -165,19 +171,20 @@ async def call( expiration_time = Timestamp() expiration_time.FromDatetime(exp) - req = RunRequest( - function=call.function, - input=call.input, - creation_time=creation_time, - expiration_time=expiration_time, - dispatch_id=dispatch_id, - parent_dispatch_id=parent_dispatch_id, - root_dispatch_id=root_dispatch_id, - ) + def make_request(call: Call) -> RunRequest: + return RunRequest( + function=call.function, + input=call.input, + creation_time=creation_time, + expiration_time=expiration_time, + dispatch_id=dispatch_id, + parent_dispatch_id=parent_dispatch_id, + root_dispatch_id=root_dispatch_id, + ) - endpoint = call.endpoint + req = make_request(call) while True: - res = await self.run(session, endpoint, req) + res = await self.run(call.endpoint, req) if res.status != STATUS_OK: # TODO: emulate retries etc... @@ -186,44 +193,60 @@ async def call( error=Error(type="status", message=str(res.status)), ) - if res.exit: - if res.exit.tail_call: - req.function = res.exit.tail_call.function - req.input = res.exit.tail_call.input - req.poll_result = None # type: ignore + if res.HasField("exit"): + if res.exit.HasField("tail_call"): + req = make_request(res.exit.tail_call) continue + result = res.exit.result return CallResult( dispatch_id=dispatch_id, - output=res.exit.result.output, - error=res.exit.result.error, + output=result.output if result.HasField("output") else None, + error=result.error if result.HasField("error") else None, ) - for call in res.poll.calls: - if not call.endpoint: - call.endpoint = endpoint - # TODO: enforce poll limits - req.input = None # type: ignore - req.poll_result = PollResult( - coroutine_state=res.poll.coroutine_state, - results=await asyncio.gather( - *[ - self.call(session, call, dispatch_id) - for call, dispatch_id in zip( - res.poll.calls, next(self.dispatch_ids) - ) - ] + results = await asyncio.gather( + *[ + self.call( + call=subcall, + dispatch_id=subcall_dispatch_id, + parent_dispatch_id=dispatch_id, + root_dispatch_id=root_dispatch_id, + ) + for subcall, subcall_dispatch_id in zip( + res.poll.calls, next(self.dispatch_ids) + ) + ] + ) + + req = RunRequest( + function=req.function, + creation_time=creation_time, + expiration_time=expiration_time, + dispatch_id=dispatch_id, + parent_dispatch_id=parent_dispatch_id, + root_dispatch_id=root_dispatch_id, + poll_result=PollResult( + coroutine_state=res.poll.coroutine_state, + results=results, ), ) - async def run( - self, session: aiohttp.ClientSession, endpoint: str, req: RunRequest - ) -> RunResponse: - async with await session.post( + async def run(self, endpoint: str, req: RunRequest) -> RunResponse: + async with await self.session.post( f"{endpoint}/dispatch.sdk.v1.FunctionService/Run", data=req.SerializeToString(), ) as response: - return RunResponse.FromString(await response.read()) + data = await response.read() + if response.status != 200: + raise ClientError.from_response(response.status, data) + return RunResponse.FromString(data) + + @property + def session(self) -> aiohttp.ClientSession: + if not self._session: + self._session = aiohttp.ClientSession() + return self._session async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: @@ -238,6 +261,7 @@ async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: reg.endpoint = server.url await fn(reg) finally: + await api.close() # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially. dispatch.experimental.durable.registry.clear_functions() @@ -263,3 +287,51 @@ def wrapper(self: T): return run(Registry(), lambda reg: fn(self, reg)) return wrapper + + +class TestCase(unittest.TestCase): + + def dispatch_test_init(self, api_key: str, api_url: str) -> BaseRegistry: + raise NotImplementedError + + def dispatch_test_run(self): + raise NotImplementedError + + def dispatch_test_stop(self): + raise NotImplementedError + + def setUp(self): + self.service = Service() + self.server = Server(self.service) + self.loop = asyncio.new_event_loop() + self.loop.run_until_complete(self.server.start()) + + self.dispatch = self.dispatch_test_init( + api_key=DISPATCH_API_KEY, api_url=self.server.url + ) + self.dispatch.client = Client( + api_key=self.dispatch.client.api_key.value, + api_url=self.dispatch.client.api_url.value, + ) + + self.thread = threading.Thread(target=self.dispatch_test_run) + self.thread.start() + + def tearDown(self): + self.dispatch_test_stop() + self.thread.join() + + self.loop.run_until_complete(self.service.close()) + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() + + def test_simple_end_to_end(self): + @self.dispatch.function + def my_function(name: str) -> str: + return f"Hello world: {name}" + + async def test(): + msg = await my_function("52") + assert msg == "Hello world: 52" + + self.loop.run_until_complete(test()) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 68e22b56..7eceb5ff 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,28 +1,33 @@ import asyncio -import base64 -import os -import pickle -import struct -import unittest +import socket from typing import Any, Optional -from unittest import mock import fastapi import google.protobuf.any_pb2 import google.protobuf.wrappers_pb2 import httpx +import uvicorn from cryptography.hazmat.primitives.asymmetric.ed25519 import ( Ed25519PrivateKey, Ed25519PublicKey, ) +from fastapi import FastAPI from fastapi.testclient import TestClient import dispatch +from dispatch.asyncio import Runner from dispatch.experimental.durable.registry import clear_functions from dispatch.fastapi import Dispatch -from dispatch.function import Arguments, Error, Function, Input, Output +from dispatch.function import ( + Arguments, + Client, + Error, + Function, + Input, + Output, + Registry, +) from dispatch.proto import _any_unpickle as any_unpickle -from dispatch.proto import _pb_any_pickle as any_pickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -31,521 +36,423 @@ public_key_from_pem, ) from dispatch.status import Status -from dispatch.test import Client, DispatchServer, DispatchService, EndpointClient +from dispatch.test import EndpointClient from dispatch.test.fastapi import http_client -def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str): - return Dispatch( - app, - endpoint=endpoint, - api_key="0000000000000000", - api_url="http://127.0.0.1:10000", - ) - - -def create_endpoint_client( - app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None -): - return EndpointClient(http_client(app), signing_key) - - -class TestFastAPI(unittest.TestCase): - def test_Dispatch(self): - app = fastapi.FastAPI() - create_dispatch_instance(app, "https://127.0.0.1:9999") - - @app.get("/") - def read_root(): - return {"Hello": "World"} - - client = TestClient(app) - - # Ensure existing routes are still working. - resp = client.get("/") - self.assertEqual(resp.status_code, 200) +class TestFastAPI(dispatch.test.TestCase): - self.assertEqual(resp.json(), {"Hello": "World"}) + def dispatch_test_init(self, reg: Registry) -> str: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("127.0.0.1", 0)) + sock.listen(128) - # Ensure Dispatch root is available. - resp = client.post("/dispatch.sdk.v1.FunctionService/Run") - self.assertEqual(resp.status_code, 400) + (host, port) = sock.getsockname() - @mock.patch.dict(os.environ, {"DISPATCH_ENDPOINT_URL": ""}) - def test_Dispatch_no_endpoint(self): - app = fastapi.FastAPI() - with self.assertRaises(ValueError): - create_dispatch_instance(app, endpoint="") + app = FastAPI() + dispatch = Dispatch(app, registry=reg) - def test_Dispatch_endpoint_no_scheme(self): - app = fastapi.FastAPI() - with self.assertRaises(ValueError): - create_dispatch_instance(app, endpoint="127.0.0.1:9999") + config = uvicorn.Config(app, host=host, port=port) + self.sockets = [sock] + self.uvicorn = uvicorn.Server(config) + self.runner = Runner() + self.event = asyncio.Event() + return f"http://{host}:{port}" - def test_fastapi_simple_request(self): - app = fastapi.FastAPI() - dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/") + def dispatch_test_run(self): + loop = self.runner.get_loop() + loop.create_task(self.uvicorn.serve(self.sockets)) + self.runner.run(self.event.wait()) + self.runner.close() - @dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value( - f"You told me: '{input.input}' ({len(input.input)} characters)" - ) + for sock in self.sockets: + sock.close() - client = create_endpoint_client(app) + def dispatch_test_stop(self): + loop = self.runner.get_loop() + loop.call_soon_threadsafe(self.event.set) - req = function_pb.RunRequest( - function=my_function.name, - input=any_pickle("Hello World!"), - ) - resp = client.run(req) - - self.assertIsInstance(resp, function_pb.RunResponse) - - output = any_unpickle(resp.exit.result.output) - - self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") - - -signing_key = private_key_from_pem( - """ ------BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF ------END PRIVATE KEY----- -""" -) - -verification_key = public_key_from_pem( - """ ------BEGIN PUBLIC KEY----- -MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs= ------END PUBLIC KEY----- -""" -) +def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str): + return Dispatch( + app, + registry=Registry( + name=__name__, + endpoint=endpoint, + client=Client( + api_key="0000000000000000", + api_url="http://127.0.0.1:10000", + ), + ), + ) -class TestFullFastapi(unittest.TestCase): - def setUp(self): - self.endpoint_app = fastapi.FastAPI() - endpoint_client = create_endpoint_client(self.endpoint_app, signing_key) - - api_key = "0000000000000000" - self.dispatch_service = DispatchService( - endpoint_client, api_key, collect_roundtrips=True - ) - self.dispatch_server = DispatchServer(self.dispatch_service) - self.dispatch_client = Client(api_key, api_url=self.dispatch_server.url) - - self.dispatch = Dispatch( - self.endpoint_app, - endpoint="http://function-service", # unused - verification_key=verification_key, - api_key=api_key, - api_url=self.dispatch_server.url, - ) - - self.dispatch_server.start() - - def tearDown(self): - self.dispatch_server.stop() - - def test_simple_end_to_end(self): - # The FastAPI server. - @self.dispatch.function - def my_function(name: str) -> str: - return f"Hello world: {name}" - - call = my_function.build_call("52") - self.assertEqual(call.function.split(".")[-1], "my_function") - - # The client. - [dispatch_id] = asyncio.run( - self.dispatch_client.dispatch([my_function.build_call("52")]) - ) - - # Simulate execution for testing purposes. - self.dispatch_service.dispatch_calls() - - # Validate results. - roundtrips = self.dispatch_service.roundtrips[dispatch_id] - self.assertEqual(len(roundtrips), 1) - _, response = roundtrips[0] - self.assertEqual(any_unpickle(response.exit.result.output), "Hello world: 52") - - def test_simple_missing_signature(self): - @self.dispatch.function - async def my_function(name: str) -> str: - return f"Hello world: {name}" - - call = my_function.build_call("52") - self.assertEqual(call.function.split(".")[-1], "my_function") - - [dispatch_id] = asyncio.run(self.dispatch_client.dispatch([call])) - - self.dispatch_service.endpoint_client = create_endpoint_client( - self.endpoint_app - ) # no signing key - try: - self.dispatch_service.dispatch_calls() - except httpx.HTTPStatusError as e: - assert e.response.status_code == 403 - assert e.response.json() == { - "code": "permission_denied", - "message": 'Expected "Signature-Input" header field to be present', - } - else: - assert False, "Expected HTTPStatusError" +def create_endpoint_client( + app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None +): + return EndpointClient(http_client(app), signing_key) def response_output(resp: function_pb.RunResponse) -> Any: return any_unpickle(resp.exit.result.output) -class TestCoroutine(unittest.TestCase): - def setUp(self): - clear_functions() - - self.app = fastapi.FastAPI() - - @self.app.get("/") - def root(): - return "OK" - - self.dispatch = create_dispatch_instance( - self.app, endpoint="https://127.0.0.1:9999" - ) - self.http_client = TestClient(self.app) - self.client = create_endpoint_client(self.app) - - def execute( - self, func: Function, input=None, state=None, calls=None - ) -> function_pb.RunResponse: - """Test helper to invoke coroutines on the local server.""" - req = function_pb.RunRequest(function=func.name) - - if input is not None: - any = any_pickle(input) - req.input.CopyFrom(any) - if state is not None: - any = any_pickle(state) - req.poll_result.typed_coroutine_state.CopyFrom(any) - if calls is not None: - for c in calls: - req.poll_result.results.append(c) - - resp = self.client.run(req) - self.assertIsInstance(resp, function_pb.RunResponse) - return resp - - def call(self, func: Function, *args, **kwargs) -> function_pb.RunResponse: - return self.execute(func, input=Arguments(args, kwargs)) - - def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: - req = function_pb.RunRequest( - function=call.function, - input=call.input, - ) - resp = self.client.run(req) - self.assertIsInstance(resp, function_pb.RunResponse) - - resp.exit.result.correlation_id = call.correlation_id - return resp.exit.result - - def test_no_input(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value("Hello World!") - - resp = self.execute(my_function) - - out = response_output(resp) - self.assertEqual(out, "Hello World!") - - def test_missing_coroutine(self): - req = function_pb.RunRequest( - function="does-not-exist", - ) - - with self.assertRaises(httpx.HTTPStatusError) as cm: - self.client.run(req) - self.assertEqual(cm.exception.response.status_code, 404) - - def test_string_input(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value(f"You sent '{input.input}'") - - resp = self.execute(my_function, input="cool stuff") - out = response_output(resp) - self.assertEqual(out, "You sent 'cool stuff'") - - def test_error_on_access_state_in_first_call(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - try: - print(input.coroutine_state) - except ValueError: - return Output.error( - Error.from_exception( - ValueError("This input is for a first function call") - ) - ) - return Output.value("not reached") - - resp = self.execute(my_function, input="cool stuff") - self.assertEqual("ValueError", resp.exit.result.error.type) - self.assertEqual( - "This input is for a first function call", resp.exit.result.error.message - ) - - def test_error_on_access_input_in_second_call(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - if input.is_first_call: - return Output.poll(coroutine_state=b"42") - try: - print(input.input) - except ValueError: - return Output.error( - Error.from_exception( - ValueError("This input is for a resumed coroutine") - ) - ) - return Output.value("not reached") - - resp = self.execute(my_function, input="cool stuff") - state = any_unpickle(resp.poll.typed_coroutine_state) - self.assertEqual(b"42", state) - - resp = self.execute(my_function, state=state) - self.assertEqual("ValueError", resp.exit.result.error.type) - self.assertEqual( - "This input is for a resumed coroutine", resp.exit.result.error.message - ) - - def test_duplicate_coro(self): - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value("Do one thing") - - with self.assertRaises(ValueError): - - @self.dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value("Do something else") - - def test_two_simple_coroutines(self): - @self.dispatch.primitive_function - async def echoroutine(input: Input) -> Output: - return Output.value(f"Echo: '{input.input}'") - - @self.dispatch.primitive_function - async def len_coroutine(input: Input) -> Output: - return Output.value(f"Length: {len(input.input)}") - - data = "cool stuff" - resp = self.execute(echoroutine, input=data) - out = response_output(resp) - self.assertEqual(out, "Echo: 'cool stuff'") - - resp = self.execute(len_coroutine, input=data) - out = response_output(resp) - self.assertEqual(out, "Length: 10") - - def test_coroutine_with_state(self): - @self.dispatch.primitive_function - async def coroutine3(input: Input) -> Output: - if input.is_first_call: - counter = input.input - else: - counter = input.coroutine_state - counter -= 1 - if counter <= 0: - return Output.value("done") - return Output.poll(coroutine_state=counter) - - # first call - resp = self.execute(coroutine3, input=4) - state = any_unpickle(resp.poll.typed_coroutine_state) - self.assertEqual(state, 3) - - # resume, state = 3 - resp = self.execute(coroutine3, state=state) - state = any_unpickle(resp.poll.typed_coroutine_state) - self.assertEqual(state, 2) - - # resume, state = 2 - resp = self.execute(coroutine3, state=state) - state = any_unpickle(resp.poll.typed_coroutine_state) - self.assertEqual(state, 1) - - # resume, state = 1 - resp = self.execute(coroutine3, state=state) - out = response_output(resp) - self.assertEqual(out, "done") - - def test_coroutine_poll(self): - @self.dispatch.primitive_function - async def coro_compute_len(input: Input) -> Output: - return Output.value(len(input.input)) - - @self.dispatch.primitive_function - async def coroutine_main(input: Input) -> Output: - if input.is_first_call: - text: str = input.input - return Output.poll( - coroutine_state=text, - calls=[coro_compute_len._build_primitive_call(text)], - ) - text = input.coroutine_state - length = input.call_results[0].output - return Output.value(f"length={length} text='{text}'") - - resp = self.execute(coroutine_main, input="cool stuff") - - # main saved some state - state = any_unpickle(resp.poll.typed_coroutine_state) - self.assertEqual(state, "cool stuff") - # main asks for 1 call to compute_len - self.assertEqual(len(resp.poll.calls), 1) - call = resp.poll.calls[0] - self.assertEqual(call.function, coro_compute_len.name) - self.assertEqual(any_unpickle(call.input), "cool stuff") - - # make the requested compute_len - resp2 = self.proto_call(call) - # check the result is the terminal expected response - len_resp = any_unpickle(resp2.output) - self.assertEqual(10, len_resp) - - # resume main with the result - resp = self.execute(coroutine_main, state=state, calls=[resp2]) - # validate the final result - out = response_output(resp) - self.assertEqual("length=10 text='cool stuff'", out) - - def test_coroutine_poll_error(self): - @self.dispatch.primitive_function - async def coro_compute_len(input: Input) -> Output: - return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead")) - - @self.dispatch.primitive_function - async def coroutine_main(input: Input) -> Output: - if input.is_first_call: - text: str = input.input - return Output.poll( - coroutine_state=text, - calls=[coro_compute_len._build_primitive_call(text)], - ) - error = input.call_results[0].error - if error is not None: - return Output.value(f"msg={error.message} type='{error.type}'") - else: - raise RuntimeError(f"unexpected call results: {input.call_results}") - - resp = self.execute(coroutine_main, input="cool stuff") - - # main saved some state - state = any_unpickle(resp.poll.typed_coroutine_state) - self.assertEqual(state, "cool stuff") - # main asks for 1 call to compute_len - self.assertEqual(len(resp.poll.calls), 1) - call = resp.poll.calls[0] - self.assertEqual(call.function, coro_compute_len.name) - self.assertEqual(any_unpickle(call.input), "cool stuff") - - # make the requested compute_len - resp2 = self.proto_call(call) - - # resume main with the result - resp = self.execute(coroutine_main, state=state, calls=[resp2]) - # validate the final result - out = response_output(resp) - self.assertEqual(out, "msg=Dead type='type'") - - def test_coroutine_error(self): - @self.dispatch.primitive_function - async def mycoro(input: Input) -> Output: - return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead")) - - resp = self.execute(mycoro) - self.assertEqual("sometype", resp.exit.result.error.type) - self.assertEqual("dead", resp.exit.result.error.message) - - def test_coroutine_expected_exception(self): - @self.dispatch.primitive_function - async def mycoro(input: Input) -> Output: - try: - 1 / 0 - except ZeroDivisionError as e: - return Output.error(Error.from_exception(e)) - self.fail("should not reach here") - - resp = self.execute(mycoro) - self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) - self.assertEqual("division by zero", resp.exit.result.error.message) - self.assertEqual(Status.PERMANENT_ERROR, resp.status) - - def test_coroutine_unexpected_exception(self): - @self.dispatch.function - def mycoro(): - 1 / 0 - self.fail("should not reach here") - - resp = self.call(mycoro) - self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) - self.assertEqual("division by zero", resp.exit.result.error.message) - self.assertEqual(Status.PERMANENT_ERROR, resp.status) - - def test_specific_status(self): - @self.dispatch.primitive_function - async def mycoro(input: Input) -> Output: - return Output.error(Error(Status.THROTTLED, "foo", "bar")) - - resp = self.execute(mycoro) - self.assertEqual("foo", resp.exit.result.error.type) - self.assertEqual("bar", resp.exit.result.error.message) - self.assertEqual(Status.THROTTLED, resp.status) - - def test_tailcall(self): - @self.dispatch.function - def other_coroutine(value: Any) -> str: - return f"Hello {value}" - - @self.dispatch.primitive_function - async def mycoro(input: Input) -> Output: - return Output.tail_call(other_coroutine._build_primitive_call(42)) - - resp = self.call(mycoro) - self.assertEqual(other_coroutine.name, resp.exit.tail_call.function) - self.assertEqual(42, any_unpickle(resp.exit.tail_call.input)) - - def test_library_error_categorization(self): - @self.dispatch.function - def get(path: str) -> httpx.Response: - http_response = self.http_client.get(path) - http_response.raise_for_status() - return http_response - - resp = self.call(get, "/") - self.assertEqual(Status.OK, Status(resp.status)) - http_response = any_unpickle(resp.exit.result.output) - self.assertEqual("application/json", http_response.headers["content-type"]) - self.assertEqual('"OK"', http_response.text) - - resp = self.call(get, "/missing") - self.assertEqual(Status.NOT_FOUND, Status(resp.status)) - - def test_library_output_categorization(self): - @self.dispatch.function - def get(path: str) -> httpx.Response: - http_response = self.http_client.get(path) - http_response.status_code = 429 - return http_response - - resp = self.call(get, "/") - self.assertEqual(Status.THROTTLED, Status(resp.status)) - http_response = any_unpickle(resp.exit.result.output) - self.assertEqual("application/json", http_response.headers["content-type"]) - self.assertEqual('"OK"', http_response.text) +# class TestCoroutine(unittest.TestCase): +# def setUp(self): +# clear_functions() + +# self.app = fastapi.FastAPI() + +# @self.app.get("/") +# def root(): +# return "OK" + +# self.dispatch = create_dispatch_instance( +# self.app, endpoint="https://127.0.0.1:9999" +# ) +# self.http_client = TestClient(self.app) +# self.client = create_endpoint_client(self.app) + +# def tearDown(self): +# self.dispatch.registry.close() + +# def execute( +# self, func: Function, input=None, state=None, calls=None +# ) -> function_pb.RunResponse: +# """Test helper to invoke coroutines on the local server.""" +# req = function_pb.RunRequest(function=func.name) + +# if input is not None: +# input_bytes = pickle.dumps(input) +# input_any = google.protobuf.any_pb2.Any() +# input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes)) +# req.input.CopyFrom(input_any) +# if state is not None: +# req.poll_result.coroutine_state = state +# if calls is not None: +# for c in calls: +# req.poll_result.results.append(c) + +# resp = self.client.run(req) +# self.assertIsInstance(resp, function_pb.RunResponse) +# return resp + +# def call(self, func: Function, *args, **kwargs) -> function_pb.RunResponse: +# return self.execute(func, input=Arguments(args, kwargs)) + +# def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: +# req = function_pb.RunRequest( +# function=call.function, +# input=call.input, +# ) +# resp = self.client.run(req) +# self.assertIsInstance(resp, function_pb.RunResponse) + +# # Assert the response is terminal. Good enough until the test client can +# # orchestrate coroutines. +# self.assertTrue(len(resp.poll.coroutine_state) == 0) + +# resp.exit.result.correlation_id = call.correlation_id +# return resp.exit.result + +# def test_no_input(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# return Output.value("Hello World!") + +# resp = self.execute(my_function) + +# out = response_output(resp) +# self.assertEqual(out, "Hello World!") + +# def test_missing_coroutine(self): +# req = function_pb.RunRequest( +# function="does-not-exist", +# ) + +# with self.assertRaises(httpx.HTTPStatusError) as cm: +# self.client.run(req) +# self.assertEqual(cm.exception.response.status_code, 404) + +# def test_string_input(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# return Output.value(f"You sent '{input.input}'") + +# resp = self.execute(my_function, input="cool stuff") +# out = response_output(resp) +# self.assertEqual(out, "You sent 'cool stuff'") + +# def test_error_on_access_state_in_first_call(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# try: +# print(input.coroutine_state) +# except ValueError: +# return Output.error( +# Error.from_exception( +# ValueError("This input is for a first function call") +# ) +# ) +# return Output.value("not reached") + +# resp = self.execute(my_function, input="cool stuff") +# self.assertEqual("ValueError", resp.exit.result.error.type) +# self.assertEqual( +# "This input is for a first function call", resp.exit.result.error.message +# ) + +# def test_error_on_access_input_in_second_call(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# if input.is_first_call: +# return Output.poll(coroutine_state=b"42") +# try: +# print(input.input) +# except ValueError: +# return Output.error( +# Error.from_exception( +# ValueError("This input is for a resumed coroutine") +# ) +# ) +# return Output.value("not reached") + +# resp = self.execute(my_function, input="cool stuff") +# self.assertEqual(b"42", resp.poll.coroutine_state) + +# resp = self.execute(my_function, state=resp.poll.coroutine_state) +# self.assertEqual("ValueError", resp.exit.result.error.type) +# self.assertEqual( +# "This input is for a resumed coroutine", resp.exit.result.error.message +# ) + +# def test_duplicate_coro(self): +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# return Output.value("Do one thing") + +# with self.assertRaises(ValueError): + +# @self.dispatch.primitive_function +# async def my_function(input: Input) -> Output: +# return Output.value("Do something else") + +# def test_two_simple_coroutines(self): +# @self.dispatch.primitive_function +# async def echoroutine(input: Input) -> Output: +# return Output.value(f"Echo: '{input.input}'") + +# @self.dispatch.primitive_function +# async def len_coroutine(input: Input) -> Output: +# return Output.value(f"Length: {len(input.input)}") + +# data = "cool stuff" +# resp = self.execute(echoroutine, input=data) +# out = response_output(resp) +# self.assertEqual(out, "Echo: 'cool stuff'") + +# resp = self.execute(len_coroutine, input=data) +# out = response_output(resp) +# self.assertEqual(out, "Length: 10") + +# def test_coroutine_with_state(self): +# @self.dispatch.primitive_function +# async def coroutine3(input: Input) -> Output: +# if input.is_first_call: +# counter = input.input +# else: +# (counter,) = struct.unpack("@i", input.coroutine_state) +# counter -= 1 +# if counter <= 0: +# return Output.value("done") +# coroutine_state = struct.pack("@i", counter) +# return Output.poll(coroutine_state=coroutine_state) + +# # first call +# resp = self.execute(coroutine3, input=4) +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) + +# # resume, state = 3 +# resp = self.execute(coroutine3, state=state) +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) + +# # resume, state = 2 +# resp = self.execute(coroutine3, state=state) +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) + +# # resume, state = 1 +# resp = self.execute(coroutine3, state=state) +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) == 0) +# out = response_output(resp) +# self.assertEqual(out, "done") + +# def test_coroutine_poll(self): +# @self.dispatch.primitive_function +# async def coro_compute_len(input: Input) -> Output: +# return Output.value(len(input.input)) + +# @self.dispatch.primitive_function +# async def coroutine_main(input: Input) -> Output: +# if input.is_first_call: +# text: str = input.input +# return Output.poll( +# coroutine_state=text.encode(), +# calls=[coro_compute_len._build_primitive_call(text)], +# ) +# text = input.coroutine_state.decode() +# length = input.call_results[0].output +# return Output.value(f"length={length} text='{text}'") + +# resp = self.execute(coroutine_main, input="cool stuff") + +# # main saved some state +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) +# # main asks for 1 call to compute_len +# self.assertEqual(len(resp.poll.calls), 1) +# call = resp.poll.calls[0] +# self.assertEqual(call.function, coro_compute_len.name) +# self.assertEqual(any_unpickle(call.input), "cool stuff") + +# # make the requested compute_len +# resp2 = self.proto_call(call) +# # check the result is the terminal expected response +# len_resp = any_unpickle(resp2.output) +# self.assertEqual(10, len_resp) + +# # resume main with the result +# resp = self.execute(coroutine_main, state=state, calls=[resp2]) +# # validate the final result +# self.assertTrue(len(resp.poll.coroutine_state) == 0) +# out = response_output(resp) +# self.assertEqual("length=10 text='cool stuff'", out) + +# def test_coroutine_poll_error(self): +# @self.dispatch.primitive_function +# async def coro_compute_len(input: Input) -> Output: +# return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead")) + +# @self.dispatch.primitive_function +# async def coroutine_main(input: Input) -> Output: +# if input.is_first_call: +# text: str = input.input +# return Output.poll( +# coroutine_state=text.encode(), +# calls=[coro_compute_len._build_primitive_call(text)], +# ) +# error = input.call_results[0].error +# if error is not None: +# return Output.value(f"msg={error.message} type='{error.type}'") +# else: +# raise RuntimeError(f"unexpected call results: {input.call_results}") + +# resp = self.execute(coroutine_main, input="cool stuff") + +# # main saved some state +# state = resp.poll.coroutine_state +# self.assertTrue(len(state) > 0) +# # main asks for 1 call to compute_len +# self.assertEqual(len(resp.poll.calls), 1) +# call = resp.poll.calls[0] +# self.assertEqual(call.function, coro_compute_len.name) +# self.assertEqual(any_unpickle(call.input), "cool stuff") + +# # make the requested compute_len +# resp2 = self.proto_call(call) + +# # resume main with the result +# resp = self.execute(coroutine_main, state=state, calls=[resp2]) +# # validate the final result +# self.assertTrue(len(resp.poll.coroutine_state) == 0) +# out = response_output(resp) +# self.assertEqual(out, "msg=Dead type='type'") + +# def test_coroutine_error(self): +# @self.dispatch.primitive_function +# async def mycoro(input: Input) -> Output: +# return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead")) + +# resp = self.execute(mycoro) +# self.assertEqual("sometype", resp.exit.result.error.type) +# self.assertEqual("dead", resp.exit.result.error.message) + +# def test_coroutine_expected_exception(self): +# @self.dispatch.primitive_function +# async def mycoro(input: Input) -> Output: +# try: +# 1 / 0 +# except ZeroDivisionError as e: +# return Output.error(Error.from_exception(e)) +# self.fail("should not reach here") + +# resp = self.execute(mycoro) +# self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) +# self.assertEqual("division by zero", resp.exit.result.error.message) +# self.assertEqual(Status.PERMANENT_ERROR, resp.status) + +# def test_coroutine_unexpected_exception(self): +# @self.dispatch.function +# def mycoro(): +# 1 / 0 +# self.fail("should not reach here") + +# resp = self.call(mycoro) +# self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) +# self.assertEqual("division by zero", resp.exit.result.error.message) +# self.assertEqual(Status.PERMANENT_ERROR, resp.status) + +# def test_specific_status(self): +# @self.dispatch.primitive_function +# async def mycoro(input: Input) -> Output: +# return Output.error(Error(Status.THROTTLED, "foo", "bar")) + +# resp = self.execute(mycoro) +# self.assertEqual("foo", resp.exit.result.error.type) +# self.assertEqual("bar", resp.exit.result.error.message) +# self.assertEqual(Status.THROTTLED, resp.status) + +# def test_tailcall(self): +# @self.dispatch.function +# def other_coroutine(value: Any) -> str: +# return f"Hello {value}" + +# @self.dispatch.primitive_function +# async def mycoro(input: Input) -> Output: +# return Output.tail_call(other_coroutine._build_primitive_call(42)) + +# resp = self.call(mycoro) +# self.assertEqual(other_coroutine.name, resp.exit.tail_call.function) +# self.assertEqual(42, any_unpickle(resp.exit.tail_call.input)) + +# def test_library_error_categorization(self): +# @self.dispatch.function +# def get(path: str) -> httpx.Response: +# http_response = self.http_client.get(path) +# http_response.raise_for_status() +# return http_response + +# resp = self.call(get, "/") +# self.assertEqual(Status.OK, Status(resp.status)) +# http_response = any_unpickle(resp.exit.result.output) +# self.assertEqual("application/json", http_response.headers["content-type"]) +# self.assertEqual('"OK"', http_response.text) + +# resp = self.call(get, "/missing") +# self.assertEqual(Status.NOT_FOUND, Status(resp.status)) + +# def test_library_output_categorization(self): +# @self.dispatch.function +# def get(path: str) -> httpx.Response: +# http_response = self.http_client.get(path) +# http_response.status_code = 429 +# return http_response + +# resp = self.call(get, "/") +# self.assertEqual(Status.THROTTLED, Status(resp.status)) +# http_response = any_unpickle(resp.exit.result.output) +# self.assertEqual("application/json", http_response.headers["content-type"]) +# self.assertEqual('"OK"', http_response.text) diff --git a/tests/test_flask.py b/tests/test_flask.py index b78c56be..494199bb 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -1,138 +1,28 @@ -import asyncio -import base64 -import os -import pickle -import struct -import unittest -from typing import Any, Optional -from unittest import mock +from wsgiref.simple_server import make_server -import google.protobuf.any_pb2 -import google.protobuf.wrappers_pb2 -from cryptography.hazmat.primitives.asymmetric.ed25519 import ( - Ed25519PrivateKey, - Ed25519PublicKey, -) from flask import Flask import dispatch -from dispatch.experimental.durable.registry import clear_functions +import dispatch.test from dispatch.flask import Dispatch -from dispatch.function import Arguments, Error, Function, Input, Output -from dispatch.proto import _any_unpickle as any_unpickle -from dispatch.proto import _pb_any_pickle as any_pickle -from dispatch.sdk.v1 import call_pb2 as call_pb -from dispatch.sdk.v1 import function_pb2 as function_pb -from dispatch.signature import ( - parse_verification_key, - private_key_from_pem, - public_key_from_pem, -) -from dispatch.status import Status -from dispatch.test import Client, DispatchServer, DispatchService, EndpointClient -from dispatch.test.flask import http_client +from dispatch.function import Registry -def create_dispatch_instance(app: Flask, endpoint: str): - return Dispatch( - app, - endpoint=endpoint, - api_key="0000000000000000", - api_url="http://127.0.0.1:10000", - ) +class TestFlask(dispatch.test.TestCase): + def dispatch_test_init(self, reg: Registry) -> str: + host = "127.0.0.1" + port = 56789 -def create_endpoint_client(app: Flask, signing_key: Optional[Ed25519PrivateKey] = None): - return EndpointClient(http_client(app), signing_key) + app = Flask("test") + dispatch = Dispatch(app, registry=reg) + self.wsgi = make_server(host, port, app) + return f"http://{host}:{port}" -class TestFlask(unittest.TestCase): - def test_flask(self): - app = Flask(__name__) - dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/") + def dispatch_test_run(self): + self.wsgi.serve_forever(poll_interval=0.05) - @dispatch.primitive_function - async def my_function(input: Input) -> Output: - return Output.value( - f"You told me: '{input.input}' ({len(input.input)} characters)" - ) - - client = create_endpoint_client(app) - - req = function_pb.RunRequest( - function=my_function.name, input=any_pickle("Hello World!") - ) - - resp = client.run(req) - - self.assertIsInstance(resp, function_pb.RunResponse) - - output = any_unpickle(resp.exit.result.output) - - self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") - - -signing_key = private_key_from_pem( - """ ------BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF ------END PRIVATE KEY----- -""" -) - -verification_key = public_key_from_pem( - """ ------BEGIN PUBLIC KEY----- -MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs= ------END PUBLIC KEY----- -""" -) - - -class TestFlaskE2E(unittest.TestCase): - def setUp(self): - self.endpoint_app = Flask(__name__) - endpoint_client = create_endpoint_client(self.endpoint_app, signing_key) - - api_key = "0000000000000000" - self.dispatch_service = DispatchService( - endpoint_client, api_key, collect_roundtrips=True - ) - self.dispatch_server = DispatchServer(self.dispatch_service) - self.dispatch_client = Client(api_key, api_url=self.dispatch_server.url) - - self.dispatch = Dispatch( - self.endpoint_app, - endpoint="http://function-service", # unused - verification_key=verification_key, - api_key=api_key, - api_url=self.dispatch_server.url, - ) - - self.dispatch_server.start() - - def tearDown(self): - self.dispatch_server.stop() - - def test_simple_end_to_end(self): - # The Flask server. - @self.dispatch.function - def my_function(name: str) -> str: - return f"Hello world: {name}" - - call = my_function.build_call("52") - self.assertEqual(call.function.split(".")[-1], "my_function") - - # The client. - [dispatch_id] = asyncio.run( - self.dispatch_client.dispatch([my_function.build_call("52")]) - ) - - # Simulate execution for testing purposes. - self.dispatch_service.dispatch_calls() - - # Validate results. - roundtrips = self.dispatch_service.roundtrips[dispatch_id] - self.assertEqual(len(roundtrips), 1) - _, response = roundtrips[0] - self.assertEqual(any_unpickle(response.exit.result.output), "Hello world: 52") + def dispatch_test_stop(self): + self.wsgi.shutdown() + self.wsgi.server_close() From 42810bc1e896eea24abfe55fe7b9cca997e8f8dd Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 11 Jun 2024 17:12:08 -0700 Subject: [PATCH 03/34] port http tests to generic dispatch test suite Signed-off-by: Achille Roussel --- src/dispatch/fastapi.py | 12 +++++- src/dispatch/flask.py | 10 ++++- src/dispatch/http.py | 33 ++++++++-------- src/dispatch/test/__init__.py | 60 +++++++++++++++++++++++++++++ tests/test_http.py | 71 +++++++++++++++++++++++++++++++++++ 5 files changed, 168 insertions(+), 18 deletions(-) diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 09018493..660ebf53 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -25,7 +25,11 @@ def read_root(): import fastapi.responses from dispatch.function import Registry -from dispatch.http import FunctionServiceError, function_service_run +from dispatch.http import ( + FunctionServiceError, + function_service_run, + validate_content_length, +) from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) @@ -97,6 +101,12 @@ async def on_error(request: fastapi.Request, exc: FunctionServiceError): "/Run", ) async def execute(request: fastapi.Request): + valid, reason = validate_content_length( + int(request.headers.get("content-length", 0)) + ) + if not valid: + raise FunctionServiceError(400, "invalid_argument", reason) + # Raw request body bytes are only available through the underlying # starlette Request object's body method, which returns an awaitable, # forcing execute() to be async. diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index b65270a4..1ce08676 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -24,7 +24,11 @@ def read_root(): from flask import Flask, make_response, request from dispatch.function import Registry -from dispatch.http import FunctionServiceError, function_service_run +from dispatch.http import ( + FunctionServiceError, + function_service_run, + validate_content_length, +) from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) @@ -87,6 +91,10 @@ def _handle_error(self, exc: FunctionServiceError): return {"code": exc.code, "message": exc.message}, exc.status def _execute(self): + valid, reason = validate_content_length(request.content_length or 0) + if not valid: + return {"code": "invalid_argument", "message": reason}, 400 + data: bytes = request.get_data(cache=False) content = asyncio.run( diff --git a/src/dispatch/http.py b/src/dispatch/http.py index bfc4c5e8..ccbcfe94 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -5,7 +5,7 @@ import os from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Iterable, List, Mapping, Optional, Union +from typing import Iterable, List, Mapping, Optional, Tuple, Union from aiohttp import web from http_message_signatures import InvalidSignature @@ -34,6 +34,16 @@ def __init__(self, status, code, message): self.message = message +def validate_content_length(content_length: int) -> Tuple[bool, str]: + if content_length == 0: + return False, "content length is required" + if content_length < 0: + return False, "content length is negative" + if content_length > 16_000_000: + return False, "content length is too large" + return True, "" + + class FunctionService(BaseHTTPRequestHandler): def __init__( @@ -78,14 +88,9 @@ def do_POST(self): return content_length = int(self.headers.get("Content-Length", 0)) - if content_length == 0: - self.send_error_response_invalid_argument("content length is required") - return - if content_length < 0: - self.send_error_response_invalid_argument("content length is negative") - return - if content_length > 16_000_000: - self.send_error_response_invalid_argument("content length is too large") + valid, reason = validate_content_length(content_length) + if not valid: + self.send_error_response_invalid_argument(reason) return data: bytes = self.rfile.read(content_length) @@ -229,13 +234,9 @@ async def function_service_run_handler( function_registry: Registry, verification_key: Optional[Ed25519PublicKey], ) -> web.Response: - content_length = request.content_length - if content_length is None or content_length == 0: - return make_error_response_invalid_argument("content length is required") - if content_length < 0: - return make_error_response_invalid_argument("content length is negative") - if content_length > 16_000_000: - return make_error_response_invalid_argument("content length is too large") + valid, reason = validate_content_length(request.content_length or 0) + if not valid: + return make_error_response_invalid_argument(reason) data: bytes = await request.read() try: diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 0f07e521..48a2f388 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -1,4 +1,5 @@ import asyncio +import json import threading import unittest from datetime import datetime, timedelta @@ -289,6 +290,16 @@ def wrapper(self: T): return wrapper +def aiotest( + fn: Callable[["TestCase"], Coroutine[Any, Any, None]] +) -> Callable[["TestCase"], None]: + @wraps(fn) + def wrapper(self): + self.loop.run_until_complete(fn(self)) + + return wrapper + + class TestCase(unittest.TestCase): def dispatch_test_init(self, api_key: str, api_url: str) -> BaseRegistry: @@ -325,6 +336,42 @@ def tearDown(self): self.loop.run_until_complete(self.loop.shutdown_asyncgens()) self.loop.close() + # TODO: let's figure out how to get rid of this global registry + # state at some point, which forces tests to be run sequentially. + dispatch.experimental.durable.registry.clear_functions() + + @aiotest + async def test_content_length_missing(self): + async with aiohttp.ClientSession( + request_class=ClientRequestContentLengthMissing + ) as session: + async with await session.post( + f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run", + ) as resp: + data = await resp.read() + print(data) + assert resp.status == 400 + assert json.loads(data) == { + "code": "invalid_argument", + "message": "content length is required", + } + + @aiotest + async def test_content_length_too_large(self): + async with aiohttp.ClientSession( + request_class=ClientRequestContentLengthTooLarge + ) as session: + async with await session.post( + f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run", + ) as resp: + data = await resp.read() + print(data) + assert resp.status == 400 + assert json.loads(data) == { + "code": "invalid_argument", + "message": "content length is too large", + } + def test_simple_end_to_end(self): @self.dispatch.function def my_function(name: str) -> str: @@ -335,3 +382,16 @@ async def test(): assert msg == "Hello world: 52" self.loop.run_until_complete(test()) + + +class ClientRequestContentLengthMissing(aiohttp.ClientRequest): + def update_headers(self, skip_auto_headers): + super().update_headers(skip_auto_headers) + if "Content-Length" in self.headers: + del self.headers["Content-Length"] + + +class ClientRequestContentLengthTooLarge(aiohttp.ClientRequest): + def update_headers(self, skip_auto_headers): + super().update_headers(skip_auto_headers) + self.headers["Content-Length"] = "16000001" diff --git a/tests/test_http.py b/tests/test_http.py index 34d27ab4..152a8261 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,5 +1,6 @@ import asyncio import socket +<<<<<<< HEAD from http.server import HTTPServer import dispatch.test @@ -55,3 +56,73 @@ def dispatch_test_run(self): def dispatch_test_stop(self): self.aioloop.get_loop().call_soon_threadsafe(self.aiowait.set) +======= +from datetime import datetime +from http.server import HTTPServer + +import dispatch.test +from dispatch.function import Registry +from dispatch.http import Dispatch, Server + + +class TestHTTP(dispatch.test.TestCase): + + def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("localhost", 0)) + sock.listen(128) + + (host, port) = sock.getsockname() + + reg = Registry( + endpoint=f"http://{host}:{port}", + api_key=api_key, + api_url=api_url, + ) + + self.httpserver = HTTPServer( + server_address=(host, port), + RequestHandlerClass=Dispatch(reg), + bind_and_activate=False, + ) + self.httpserver.socket = sock + return reg + + def dispatch_test_run(self): + self.httpserver.serve_forever(poll_interval=0.05) + + def dispatch_test_stop(self): + self.httpserver.shutdown() + self.httpserver.server_close() + + +class TestAIOHTTP(dispatch.test.TestCase): + + def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: + host = "localhost" + port = 0 + + reg = Registry( + endpoint=f"http://{host}:{port}", + api_key=api_key, + api_url=api_url, + ) + + self.aioloop = asyncio.new_event_loop() + self.aiohttp = Server(host, port, Dispatch(reg)) + self.aioloop.run_until_complete(self.aiohttp.start()) + + reg.endpoint = f"http://{self.aiohttp.host}:{self.aiohttp.port}" + return reg + + def dispatch_test_run(self): + self.aioloop.run_forever() + + def dispatch_test_stop(self): + def stop(): + self.aiohttp.stop() + self.aioloop.stop() + + self.aioloop.call_soon_threadsafe(stop) +>>>>>>> ed50efc (port http tests to generic dispatch test suite) From 731d96d83f1fbe38beeb428ab6639f839368075a Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 11 Jun 2024 17:44:14 -0700 Subject: [PATCH 04/34] fix slow FastAPI tests Signed-off-by: Achille Roussel --- src/dispatch/test/__init__.py | 21 +++++++---- tests/test_http.py | 71 ----------------------------------- 2 files changed, 14 insertions(+), 78 deletions(-) diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 48a2f388..1b8c14ac 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -13,7 +13,7 @@ import dispatch.experimental.durable.registry from dispatch.function import Client as BaseClient -from dispatch.function import ClientError +from dispatch.function import ClientError, Input, Output from dispatch.function import Registry as BaseRegistry from dispatch.http import Dispatch from dispatch.http import Server as BaseServer @@ -372,16 +372,23 @@ async def test_content_length_too_large(self): "message": "content length is too large", } - def test_simple_end_to_end(self): + @aiotest + async def test_call_function_no_input(self): + @self.dispatch.function + def my_function() -> str: + return "Hello World!" + + ret = await my_function() + assert ret == "Hello World!" + + @aiotest + async def test_call_function_with_input(self): @self.dispatch.function def my_function(name: str) -> str: return f"Hello world: {name}" - async def test(): - msg = await my_function("52") - assert msg == "Hello world: 52" - - self.loop.run_until_complete(test()) + ret = await my_function("52") + assert ret == "Hello world: 52" class ClientRequestContentLengthMissing(aiohttp.ClientRequest): diff --git a/tests/test_http.py b/tests/test_http.py index 152a8261..34d27ab4 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,6 +1,5 @@ import asyncio import socket -<<<<<<< HEAD from http.server import HTTPServer import dispatch.test @@ -56,73 +55,3 @@ def dispatch_test_run(self): def dispatch_test_stop(self): self.aioloop.get_loop().call_soon_threadsafe(self.aiowait.set) -======= -from datetime import datetime -from http.server import HTTPServer - -import dispatch.test -from dispatch.function import Registry -from dispatch.http import Dispatch, Server - - -class TestHTTP(dispatch.test.TestCase): - - def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("localhost", 0)) - sock.listen(128) - - (host, port) = sock.getsockname() - - reg = Registry( - endpoint=f"http://{host}:{port}", - api_key=api_key, - api_url=api_url, - ) - - self.httpserver = HTTPServer( - server_address=(host, port), - RequestHandlerClass=Dispatch(reg), - bind_and_activate=False, - ) - self.httpserver.socket = sock - return reg - - def dispatch_test_run(self): - self.httpserver.serve_forever(poll_interval=0.05) - - def dispatch_test_stop(self): - self.httpserver.shutdown() - self.httpserver.server_close() - - -class TestAIOHTTP(dispatch.test.TestCase): - - def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: - host = "localhost" - port = 0 - - reg = Registry( - endpoint=f"http://{host}:{port}", - api_key=api_key, - api_url=api_url, - ) - - self.aioloop = asyncio.new_event_loop() - self.aiohttp = Server(host, port, Dispatch(reg)) - self.aioloop.run_until_complete(self.aiohttp.start()) - - reg.endpoint = f"http://{self.aiohttp.host}:{self.aiohttp.port}" - return reg - - def dispatch_test_run(self): - self.aioloop.run_forever() - - def dispatch_test_stop(self): - def stop(): - self.aiohttp.stop() - self.aioloop.stop() - - self.aioloop.call_soon_threadsafe(stop) ->>>>>>> ed50efc (port http tests to generic dispatch test suite) From 69e05bc2962bf5a00b8e9be7c4147c118f6cfdcd Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 12 Jun 2024 14:34:59 -0700 Subject: [PATCH 05/34] port some of the tests from TestCoroutine to TestCase + fix bugs Signed-off-by: Achille Roussel --- src/dispatch/flask.py | 65 ++++++++++- src/dispatch/function.py | 30 +---- src/dispatch/scheduler.py | 18 +++ src/dispatch/test/__init__.py | 202 ++++++++++++++++++++++++++++------ 4 files changed, 253 insertions(+), 62 deletions(-) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index 1ce08676..24986826 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -19,6 +19,8 @@ def read_root(): import asyncio import logging +import threading +# from queue import Queue from typing import Optional, Union from flask import Flask, make_response, request @@ -87,6 +89,42 @@ def __init__( app.errorhandler(FunctionServiceError)(self._handle_error) app.post("/dispatch.sdk.v1.FunctionService/Run")(self._execute) + # TODO: earlier experiment I ran because it seemed like tasks created + # by the /Dispatch endpoint were canceled when calls to /Wait were made. + # + # After further investigation, it might have been caused by a bug when + # setting the thread local state indicating that we are being invoked + # from a scheduler thread, which resulted in unnecessary dispatch calls. + # + # I'm keeping the code around for now in case it ends up being needed in + # the short term. Feel free to remove if you run into this comment and + # it's no longer relevant. + # --- + # Here we have to use one event loop for the whole application to allow + # tasks spawned by request handlers to persist after the request is done. + # + # This is essential for tests to pass when using the /Dispatch and /Wait + # endpoints to wait on function results. + # self._loop = asyncio.new_event_loop() + # self._thread = threading.Thread(target=self._run_event_loop) + # self._thread.start() + + # def close(self): + # self._loop.call_soon_threadsafe(self._loop.stop) + # self._thread.join() + + # def __enter__(self): + # return self + + # def __exit__(self, exc_type, exc_value, traceback): + # self.close() + + # def _run_event_loop(self): + # asyncio.set_event_loop(self._loop) + # self._loop.run_forever() + # self._loop.run_until_complete(self._loop.shutdown_asyncgens()) + # self._loop.close() + def _handle_error(self, exc: FunctionServiceError): return {"code": exc.code, "message": exc.message}, exc.status @@ -105,9 +143,34 @@ def _execute(self): data, self, self._verification_key, - ), + ) ) + # queue = Queue[asyncio.Task](maxsize=1) + # + # url, method, headers = request.url, request.method, dict(request.headers) + # def execute_task(): + # task = self._loop.create_task( + # function_service_run( + # url, + # method, + # headers, + # data, + # self, + # self._verification_key, + # ) + # ) + # task.add_done_callback(queue.put) + + # self._loop.call_soon_threadsafe(execute_task) + # task: asyncio.Task = queue.get() + + # exception = task.exception() + # if exception is not None: + # raise exception + + # content: bytes = task.result() + res = make_response(content) res.content_type = "application/proto" return res diff --git a/src/dispatch/function.py b/src/dispatch/function.py index f5807a15..5ad6647e 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -5,7 +5,6 @@ import json import logging import os -import threading from functools import wraps from types import CoroutineType from typing import ( @@ -34,7 +33,7 @@ from dispatch.experimental.durable import durable from dispatch.id import DispatchID from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall -from dispatch.scheduler import OneShotScheduler +from dispatch.scheduler import OneShotScheduler, in_function_call logger = logging.getLogger(__name__) @@ -55,29 +54,6 @@ def current_session() -> aiohttp.ClientSession: return DEFAULT_SESSION -class ThreadContext(threading.local): - in_function_call: bool - - def __init__(self): - self.in_function_call = False - - -thread_context = ThreadContext() - - -def function(func: Callable[P, T]) -> Callable[P, T]: - def scope(*args: P.args, **kwargs: P.kwargs) -> T: - if thread_context.in_function_call: - raise RuntimeError("recursively entered a dispatch function entry point") - thread_context.in_function_call = True - try: - return func(*args, **kwargs) - finally: - thread_context.in_function_call = False - - return scope - - PrimitiveFunctionType: TypeAlias = Callable[[Input], Awaitable[Output]] """A primitive function is a function that accepts a dispatch.proto.Input and unconditionally returns a dispatch.proto.Output. It must not raise @@ -163,7 +139,7 @@ async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T: async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: """Call the function asynchronously (through Dispatch), and return a coroutine that can be awaited to retrieve the call result.""" - if thread_context.in_function_call: + if in_function_call(): return await self._func_indirect(*args, **kwargs) call = self.build_call(*args, **kwargs) @@ -279,7 +255,6 @@ def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]: func = durable(func) @wraps(func) - @function async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, func, *args, **kwargs) @@ -294,7 +269,6 @@ def _register_coroutine( func = durable(func) @wraps(func) - @function async def primitive_func(input: Input) -> Output: return await OneShotScheduler(func).run(input) diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 75b2e54c..112428f0 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -2,6 +2,7 @@ import logging import pickle import sys +import threading from dataclasses import dataclass, field from types import coroutine from typing import ( @@ -32,6 +33,20 @@ CorrelationID: TypeAlias = int +class ThreadLocal(threading.local): + in_function_call: bool + + def __init__(self): + self.in_function_call = False + + +thread_local = ThreadLocal() + + +def in_function_call() -> bool: + return thread_local.in_function_call + + @dataclass class CoroutineResult: """The result from running a coroutine to completion.""" @@ -328,12 +343,15 @@ def __init__( async def run(self, input: Input) -> Output: try: + thread_local.in_function_call = True return await self._run(input) except Exception as e: logger.exception( "unexpected exception occurred during coroutine scheduling" ) return Output.error(Error.from_exception(e)) + finally: + thread_local.in_function_call = False def _init_state(self, input: Input) -> State: logger.debug("starting main coroutine") diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 1b8c14ac..aa5eec51 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -49,8 +49,8 @@ R = TypeVar("R", bound=BaseRegistry) T = TypeVar("T") -DISPATCH_ENDPOINT_URL = "http://localhost:0" -DISPATCH_API_URL = "http://localhost:0" +DISPATCH_ENDPOINT_URL = "http://127.0.0.1:0" +DISPATCH_API_URL = "http://127.0.0.1:0" DISPATCH_API_KEY = "916CC3D280BB46DDBDA984B3DD10059A" @@ -75,7 +75,7 @@ def __init__(self): class Server(BaseServer): def __init__(self, app: web.Application): - super().__init__("localhost", 0, app) + super().__init__("127.0.0.1", 0, app) @property def url(self): @@ -189,9 +189,17 @@ def make_request(call: Call) -> RunRequest: if res.status != STATUS_OK: # TODO: emulate retries etc... + if ( + res.HasField("exit") + and res.exit.HasField("result") + and res.exit.result.HasField("error") + ): + error = res.exit.result.error + else: + error = Error(type="status", message=str(res.status)) return CallResult( dispatch_id=dispatch_id, - error=Error(type="status", message=str(res.status)), + error=error, ) if res.HasField("exit"): @@ -271,6 +279,21 @@ async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: def run(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: return asyncio.run(main(reg, fn)) +# TODO: these decorators still need work, until we figure out serialization +# for cell objects, they are not very useful since the registry they receive +# as argument cannot be used to register dispatch functions. +# +# The simplest would be to use a global registry for external application tests, +# maybe we can figure out a way to make this easy with a syntax like: +# +# import main +# import dispatch.test +# +# @dispatch.test.function(main.dispatch) +# async def test_something(): +# ... +# +# (WIP) def function(fn: Callable[[Registry], Coroutine[Any, Any, None]]) -> Callable[[], None]: @wraps(fn) @@ -293,16 +316,37 @@ def wrapper(self: T): def aiotest( fn: Callable[["TestCase"], Coroutine[Any, Any, None]] ) -> Callable[["TestCase"], None]: + """Decorator to run tests declared as async methods of the TestCase class + using the event loop of the test case instance. + + This decorator is internal only, it shouldn't be exposed in the public API + of this module. + """ @wraps(fn) - def wrapper(self): - self.loop.run_until_complete(fn(self)) + def test(self): + self.server_loop.run_until_complete(fn(self)) - return wrapper + return test class TestCase(unittest.TestCase): + """TestCase implements the generic test suite used in dispatch-py to test + various integrations of the SDK with frameworks like FastAPI, Flask, etc... + + Applications typically don't use this class directly, the test suite is + mostly useful as an internal testing tool. + + Implementation of the test suite need to override the dispatch_test_init, + dispatch_test_run, and dispatch_test_stop methods to integrate with the + testing infrastructure (see the documentation of each of these methods for + more details). + """ def dispatch_test_init(self, api_key: str, api_url: str) -> BaseRegistry: + """Called to initialize each test case. The method returns the dispatch + function registry which can be used to register function instances + during tests. + """ raise NotImplementedError def dispatch_test_run(self): @@ -314,8 +358,8 @@ def dispatch_test_stop(self): def setUp(self): self.service = Service() self.server = Server(self.service) - self.loop = asyncio.new_event_loop() - self.loop.run_until_complete(self.server.start()) + self.server_loop = asyncio.new_event_loop() + self.server_loop.run_until_complete(self.server.start()) self.dispatch = self.dispatch_test_init( api_key=DISPATCH_API_KEY, api_url=self.server.url @@ -325,52 +369,73 @@ def setUp(self): api_url=self.dispatch.client.api_url.value, ) - self.thread = threading.Thread(target=self.dispatch_test_run) - self.thread.start() + self.client_thread = threading.Thread(target=self.dispatch_test_run) + self.client_thread.start() def tearDown(self): self.dispatch_test_stop() - self.thread.join() + self.client_thread.join() - self.loop.run_until_complete(self.service.close()) - self.loop.run_until_complete(self.loop.shutdown_asyncgens()) - self.loop.close() + self.server_loop.run_until_complete(self.service.close()) + self.server_loop.run_until_complete(self.server_loop.shutdown_asyncgens()) + self.server_loop.close() # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially. dispatch.experimental.durable.registry.clear_functions() + @property + def function_service_run_url(self) -> str: + return f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run" + + def test_register_duplicate_functions(self): + @self.dispatch.function + def my_function(): ... + + with self.assertRaises(ValueError): + + @self.dispatch.function + def my_function(): ... + @aiotest async def test_content_length_missing(self): async with aiohttp.ClientSession( request_class=ClientRequestContentLengthMissing ) as session: - async with await session.post( - f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run", - ) as resp: + async with await session.post(self.function_service_run_url) as resp: data = await resp.read() - print(data) - assert resp.status == 400 - assert json.loads(data) == { - "code": "invalid_argument", - "message": "content length is required", - } + self.assertEqual(resp.status, 400) + self.assertEqual( + json.loads(data), + make_error_invalid_argument("content length is required"), + ) @aiotest async def test_content_length_too_large(self): async with aiohttp.ClientSession( request_class=ClientRequestContentLengthTooLarge ) as session: + async with await session.post(self.function_service_run_url) as resp: + data = await resp.read() + self.assertEqual(resp.status, 400) + self.assertEqual( + json.loads(data), + make_error_invalid_argument("content length is too large"), + ) + + @aiotest + async def test_call_function_missing(self): + async with aiohttp.ClientSession() as session: async with await session.post( - f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run", + self.function_service_run_url, + data=RunRequest(function="does-not-exist").SerializeToString(), ) as resp: data = await resp.read() - print(data) - assert resp.status == 400 - assert json.loads(data) == { - "code": "invalid_argument", - "message": "content length is too large", - } + self.assertEqual(resp.status, 404) + self.assertEqual( + json.loads(data), + make_error_not_found("function 'does-not-exist' does not exist"), + ) @aiotest async def test_call_function_no_input(self): @@ -379,7 +444,7 @@ def my_function() -> str: return "Hello World!" ret = await my_function() - assert ret == "Hello World!" + self.assertEqual(ret, "Hello World!") @aiotest async def test_call_function_with_input(self): @@ -388,7 +453,66 @@ def my_function(name: str) -> str: return f"Hello world: {name}" ret = await my_function("52") - assert ret == "Hello world: 52" + self.assertEqual(ret, "Hello world: 52") + + @aiotest + async def test_call_function_raise_error(self): + @self.dispatch.function + def my_function(name: str) -> str: + raise ValueError("something went wrong!") + + with self.assertRaises(ValueError) as e: + await my_function("52") + + @aiotest + async def test_call_two_functions(self): + @self.dispatch.function + def echo(name: str) -> str: + return name + + @self.dispatch.function + def length(name: str) -> int: + return len(name) + + self.assertEqual(await echo("hello"), "hello") + self.assertEqual(await length("hello"), 5) + + # TODO: + # + # The declaration of nested functions in these tests causes CPython to + # generate cell objects since the local variables are referenced by multiple + # scopes. + # + # Maybe time to revisit https://github.com/dispatchrun/dispatch-py/pull/121 + # + # Alternatively, we could rewrite the test suite to use a global registry + # where we register each function once in the globla scope, so no cells need + # to be created. + + # @aiotest + # async def test_call_nested_function_with_result(self): + # @self.dispatch.function + # def echo(name: str) -> str: + # return name + + # @self.dispatch.function + # async def echo2(name: str) -> str: + # return await echo(name) + + # self.assertEqual(await echo2("hello"), "hello") + + # @aiotest + # async def test_call_nested_function_with_error(self): + # @self.dispatch.function + # def broken_function(name: str) -> str: + # raise ValueError("something went wrong!") + + # @self.dispatch.function + # async def working_function(name: str) -> str: + # return await broken_function(name) + + # with self.assertRaises(ValueError) as e: + # await working_function("hello") class ClientRequestContentLengthMissing(aiohttp.ClientRequest): @@ -402,3 +526,15 @@ class ClientRequestContentLengthTooLarge(aiohttp.ClientRequest): def update_headers(self, skip_auto_headers): super().update_headers(skip_auto_headers) self.headers["Content-Length"] = "16000001" + + +def make_error_invalid_argument(message: str) -> dict: + return make_error("invalid_argument", message) + + +def make_error_not_found(message: str) -> dict: + return make_error("not_found", message) + + +def make_error(code: str, message: str) -> dict: + return {"code": code, "message": message} From a674f071dbbfd5d304d48905ce69cae283e8417a Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 12 Jun 2024 14:51:55 -0700 Subject: [PATCH 06/34] fix compatibility with Python 3.8 Signed-off-by: Achille Roussel --- src/dispatch/flask.py | 1 + src/dispatch/test/__init__.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index 24986826..8cece0e0 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -20,6 +20,7 @@ def read_root(): import asyncio import logging import threading + # from queue import Queue from typing import Optional, Union diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index aa5eec51..7669f5e7 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -4,7 +4,7 @@ import unittest from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, Coroutine, Optional, TypeVar, overload +from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, overload import aiohttp from aiohttp import web @@ -83,7 +83,7 @@ def url(self): class Service(web.Application): - tasks: dict[str, asyncio.Task[CallResult]] + tasks: Dict[str, asyncio.Task[CallResult]] _session: Optional[aiohttp.ClientSession] = None def __init__(self, session: Optional[aiohttp.ClientSession] = None): @@ -279,6 +279,7 @@ async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: def run(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: return asyncio.run(main(reg, fn)) + # TODO: these decorators still need work, until we figure out serialization # for cell objects, they are not very useful since the registry they receive # as argument cannot be used to register dispatch functions. @@ -295,6 +296,7 @@ def run(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: # # (WIP) + def function(fn: Callable[[Registry], Coroutine[Any, Any, None]]) -> Callable[[], None]: @wraps(fn) def wrapper(): @@ -322,6 +324,7 @@ def aiotest( This decorator is internal only, it shouldn't be exposed in the public API of this module. """ + @wraps(fn) def test(self): self.server_loop.run_until_complete(fn(self)) From 8dfc03a881c0b70ff004edabd5b00201f741e3f2 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Wed, 12 Jun 2024 15:04:50 -0700 Subject: [PATCH 07/34] fix for Python 3.8: asyncio.Task is not a generic type Signed-off-by: Achille Roussel --- src/dispatch/test/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 7669f5e7..07c306cc 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -83,7 +83,7 @@ def url(self): class Service(web.Application): - tasks: Dict[str, asyncio.Task[CallResult]] + tasks: Dict[str, asyncio.Task] _session: Optional[aiohttp.ClientSession] = None def __init__(self, session: Optional[aiohttp.ClientSession] = None): From 529823d8ef9081e102bfbbbdcd87bc741a6aa2b6 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 14:03:23 -0700 Subject: [PATCH 08/34] refactor: use composition, default registry, function service Signed-off-by: Achille Roussel --- examples/auto_retry/test_app.py | 2 +- examples/getting_started/test_app.py | 2 +- examples/github_stats/test_app.py | 2 +- src/dispatch/__init__.py | 19 +- src/dispatch/experimental/durable/function.py | 10 +- src/dispatch/experimental/durable/registry.py | 12 ++ src/dispatch/experimental/lambda_handler.py | 33 ++-- src/dispatch/fastapi.py | 101 +++++------ src/dispatch/flask.py | 39 +--- src/dispatch/function.py | 166 +++++++++++------- src/dispatch/http.py | 75 +++++++- src/dispatch/test/__init__.py | 140 +++++++-------- tests/dispatch/test_function.py | 13 +- 13 files changed, 348 insertions(+), 266 deletions(-) diff --git a/examples/auto_retry/test_app.py b/examples/auto_retry/test_app.py index fca4dc2f..8ce3f188 100644 --- a/examples/auto_retry/test_app.py +++ b/examples/auto_retry/test_app.py @@ -29,7 +29,7 @@ def test_app(self): dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. - dispatch.set_client(Client(api_url=dispatch_server.url)) + dispatch.registry.client = Client(api_url=dispatch_server.url) response = app_client.get("/") self.assertEqual(response.status_code, 200) diff --git a/examples/getting_started/test_app.py b/examples/getting_started/test_app.py index 16a7f8cb..a3345b92 100644 --- a/examples/getting_started/test_app.py +++ b/examples/getting_started/test_app.py @@ -28,7 +28,7 @@ def test_app(self): dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. - dispatch.set_client(Client(api_url=dispatch_server.url)) + dispatch.registry.client = Client(api_url=dispatch_server.url) response = app_client.get("/") self.assertEqual(response.status_code, 200) diff --git a/examples/github_stats/test_app.py b/examples/github_stats/test_app.py index 48440165..37ca0d84 100644 --- a/examples/github_stats/test_app.py +++ b/examples/github_stats/test_app.py @@ -28,7 +28,7 @@ def test_app(self): dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: # Use it when dispatching function calls. - dispatch.set_client(Client(api_url=dispatch_server.url)) + dispatch.registry.client = Client(api_url=dispatch_server.url) response = app_client.get("/") self.assertEqual(response.status_code, 200) diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 812d621a..788c6a37 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -12,7 +12,15 @@ import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race -from dispatch.function import Batch, Client, ClientError, Function, Registry, Reset +from dispatch.function import ( + Batch, + Client, + ClientError, + Function, + Registry, + Reset, + default_registry, +) from dispatch.http import Dispatch from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output @@ -43,15 +51,6 @@ P = ParamSpec("P") T = TypeVar("T") -_registry: Optional[Registry] = None - - -def default_registry(): - global _registry - if not _registry: - _registry = Registry() - return _registry - @overload def function(func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... diff --git a/src/dispatch/experimental/durable/function.py b/src/dispatch/experimental/durable/function.py index 87ccca42..a014935b 100644 --- a/src/dispatch/experimental/durable/function.py +++ b/src/dispatch/experimental/durable/function.py @@ -23,7 +23,12 @@ ) from . import frame as ext -from .registry import RegisteredFunction, lookup_function, register_function +from .registry import ( + RegisteredFunction, + lookup_function, + register_function, + unregister_function, +) TRACE = os.getenv("DISPATCH_TRACE", False) @@ -58,6 +63,9 @@ def __call__(self, *args, **kwargs): def __repr__(self) -> str: return f"DurableFunction({self.__qualname__})" + def unregister(self): + unregister_function(self.registered_fn.key) + def durable(fn: Callable) -> Callable: """Returns a "durable" function that creates serializable diff --git a/src/dispatch/experimental/durable/registry.py b/src/dispatch/experimental/durable/registry.py index 6ddac075..9250ec0d 100644 --- a/src/dispatch/experimental/durable/registry.py +++ b/src/dispatch/experimental/durable/registry.py @@ -106,6 +106,18 @@ def lookup_function(key: str) -> RegisteredFunction: return _REGISTRY[key] +def unregister_function(key: str): + """Unregister a function by key. + + Args: + key: Unique identifier for the function. + + Raises: + KeyError: A function has not been registered with this key. + """ + del _REGISTRY[key] + + def clear_functions(): """Clear functions clears the registry.""" _REGISTRY.clear() diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 6aeeaca6..2b098052 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -18,15 +18,16 @@ def handler(event, context): dispatch.handle(event, context, entrypoint="entrypoint") """ +import asyncio import base64 import json import logging -from typing import Optional +from typing import Optional, Union from awslambdaric.lambda_context import LambdaContext -from dispatch.asyncio import Runner from dispatch.function import Registry +from dispatch.http import FunctionService from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.status import Status @@ -34,27 +35,15 @@ def handler(event, context): logger = logging.getLogger(__name__) -class Dispatch(Registry): +class Dispatch(FunctionService): def __init__( self, - api_key: Optional[str] = None, - api_url: Optional[str] = None, + registry: Optional[Registry] = None, ): - """Initializes a Dispatch Lambda handler. - - Args: - api_key: Dispatch API key to use for authentication. Uses the value - of the DISPATCH_API_KEY environment variable by default. - - api_url: The URL of the Dispatch API to use. Uses the value of the - DISPATCH_API_URL environment variable if set, otherwise - defaults to the public Dispatch API (DEFAULT_API_URL). - - """ - + """Initializes a Dispatch Lambda handler.""" # We use a fake endpoint to initialize the base class. The actual endpoint (the Lambda ARN) # is only known when the handler is invoked. - super().__init__(endpoint="http://lambda", api_key=api_key, api_url=api_url) + super().__init__(registry) def handle( self, event: str, context: LambdaContext, entrypoint: Optional[str] = None @@ -63,7 +52,8 @@ def handle( # We override the endpoint of all registered functions before any execution. if context.invoked_function_arn: self.endpoint = context.invoked_function_arn - self.override_endpoint(self.endpoint) + # TODO: this might mutate the default registry, we should figure out a better way. + self.registry.endpoint = self.endpoint if not event: raise ValueError("event is required") @@ -87,14 +77,13 @@ def handle( ) try: - func = self.functions[req.function] + func = self.registry.functions[req.function] except KeyError: raise ValueError(f"function {req.function} not found") input = Input(req) try: - with Runner() as runner: - output = runner.run(func._primitive_call(input)) + output = asyncio.run(func._primitive_call(input)) except Exception: logger.error("function '%s' fatal error", req.function, exc_info=True) raise # FIXME diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 660ebf53..4b0ff9c0 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -25,26 +25,20 @@ def read_root(): import fastapi.responses from dispatch.function import Registry -from dispatch.http import ( - FunctionServiceError, - function_service_run, - validate_content_length, -) +from dispatch.http import FunctionService, FunctionServiceError, validate_content_length from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) -class Dispatch(Registry): +class Dispatch(FunctionService): """A Dispatch instance, powered by FastAPI.""" def __init__( self, app: fastapi.FastAPI, - endpoint: Optional[str] = None, + registry: Optional[Registry] = None, verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, - api_key: Optional[str] = None, - api_url: Optional[str] = None, ): """Initialize a Dispatch endpoint, and integrate it into a FastAPI app. @@ -53,9 +47,8 @@ def __init__( Args: app: The FastAPI app to configure. - endpoint: Full URL of the application the Dispatch instance will - be running on. Uses the value of the DISPATCH_ENDPOINT_URL - environment variable by default. + registry: A registry of functions to expose. If omitted, the default + registry is used. verification_key: Key to use when verifying signed requests. Uses the value of the DISPATCH_VERIFICATION_KEY environment variable @@ -64,13 +57,6 @@ def __init__( If not set, request signature verification is disabled (a warning will be logged by the constructor). - api_key: Dispatch API key to use for authentication. Uses the value of - the DISPATCH_API_KEY environment variable by default. - - api_url: The URL of the Dispatch API to use. Uses the value of the - DISPATCH_API_URL environment variable if set, otherwise - defaults to the public Dispatch API (DEFAULT_API_URL). - Raises: ValueError: If any of the required arguments are missing. """ @@ -78,49 +64,42 @@ def __init__( raise ValueError( "missing FastAPI app as first argument of the Dispatch constructor" ) - super().__init__(endpoint, api_key=api_key, api_url=api_url) - verification_key = parse_verification_key(verification_key, endpoint=endpoint) - function_service = _new_app(self, verification_key) - app.mount("/dispatch.sdk.v1.FunctionService", function_service) - - -def _new_app(function_registry: Registry, verification_key: Optional[Ed25519PublicKey]): - app = fastapi.FastAPI() - - @app.exception_handler(FunctionServiceError) - async def on_error(request: fastapi.Request, exc: FunctionServiceError): - # https://connectrpc.com/docs/protocol/#error-end-stream - return fastapi.responses.JSONResponse( - status_code=exc.status, content={"code": exc.code, "message": exc.message} - ) + super().__init__(registry, verification_key) + function_service = fastapi.FastAPI() + + @function_service.exception_handler(FunctionServiceError) + async def on_error(request: fastapi.Request, exc: FunctionServiceError): + # https://connectrpc.com/docs/protocol/#error-end-stream + return fastapi.responses.JSONResponse( + status_code=exc.status, + content={"code": exc.code, "message": exc.message}, + ) - @app.post( - # The endpoint for execution is hardcoded at the moment. If the service - # gains more endpoints, this should be turned into a dynamic dispatch - # like the official gRPC server does. - "/Run", - ) - async def execute(request: fastapi.Request): - valid, reason = validate_content_length( - int(request.headers.get("content-length", 0)) - ) - if not valid: - raise FunctionServiceError(400, "invalid_argument", reason) - - # Raw request body bytes are only available through the underlying - # starlette Request object's body method, which returns an awaitable, - # forcing execute() to be async. - data: bytes = await request.body() - - content = await function_service_run( - str(request.url), - request.method, - request.headers, - data, - function_registry, - verification_key, + @function_service.post( + # The endpoint for execution is hardcoded at the moment. If the service + # gains more endpoints, this should be turned into a dynamic dispatch + # like the official gRPC server does. + "/Run", ) + async def execute(request: fastapi.Request): + valid, reason = validate_content_length( + int(request.headers.get("content-length", 0)) + ) + if not valid: + raise FunctionServiceError(400, "invalid_argument", reason) + + # Raw request body bytes are only available through the underlying + # starlette Request object's body method, which returns an awaitable, + # forcing execute() to be async. + data: bytes = await request.body() + + content = await self.run( + str(request.url), + request.method, + request.headers, + await request.body(), + ) - return fastapi.Response(content=content, media_type="application/proto") + return fastapi.Response(content=content, media_type="application/proto") - return app + app.mount("/dispatch.sdk.v1.FunctionService", function_service) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index 8cece0e0..7991f20a 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -27,26 +27,20 @@ def read_root(): from flask import Flask, make_response, request from dispatch.function import Registry -from dispatch.http import ( - FunctionServiceError, - function_service_run, - validate_content_length, -) +from dispatch.http import FunctionService, FunctionServiceError, validate_content_length from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) -class Dispatch(Registry): +class Dispatch(FunctionService): """A Dispatch instance, powered by Flask.""" def __init__( self, app: Flask, - endpoint: Optional[str] = None, + registry: Optional[Registry] = None, verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, - api_key: Optional[str] = None, - api_url: Optional[str] = None, ): """Initialize a Dispatch endpoint, and integrate it into a Flask app. @@ -55,9 +49,8 @@ def __init__( Args: app: The Flask app to configure. - endpoint: Full URL of the application the Dispatch instance will - be running on. Uses the value of the DISPATCH_ENDPOINT_URL - environment variable by default. + registry: A registry of functions to expose. If omitted, the default + registry is used. verification_key: Key to use when verifying signed requests. Uses the value of the DISPATCH_VERIFICATION_KEY environment variable @@ -66,13 +59,6 @@ def __init__( If not set, request signature verification is disabled (a warning will be logged by the constructor). - api_key: Dispatch API key to use for authentication. Uses the value of - the DISPATCH_API_KEY environment variable by default. - - api_url: The URL of the Dispatch API to use. Uses the value of the - DISPATCH_API_URL environment variable if set, otherwise - defaults to the public Dispatch API (DEFAULT_API_URL). - Raises: ValueError: If any of the required arguments are missing. """ @@ -81,12 +67,7 @@ def __init__( "missing Flask app as first argument of the Dispatch constructor" ) - super().__init__(endpoint, api_key=api_key, api_url=api_url) - - self._verification_key = parse_verification_key( - verification_key, endpoint=endpoint - ) - + super().__init__(registry, verification_key) app.errorhandler(FunctionServiceError)(self._handle_error) app.post("/dispatch.sdk.v1.FunctionService/Run")(self._execute) @@ -134,16 +115,12 @@ def _execute(self): if not valid: return {"code": "invalid_argument", "message": reason}, 400 - data: bytes = request.get_data(cache=False) - content = asyncio.run( - function_service_run( + self.run( request.url, request.method, dict(request.headers), - data, - self, - self._verification_key, + request.get_data(cache=False), ) ) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 5ad6647e..8791297c 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -37,6 +37,9 @@ logger = logging.getLogger(__name__) +P = ParamSpec("P") +T = TypeVar("T") + class GlobalSession(aiohttp.ClientSession): async def __aexit__(self, *args): @@ -62,41 +65,40 @@ def current_session() -> aiohttp.ClientSession: class PrimitiveFunction: - __slots__ = ("_endpoint", "_client", "_name", "_primitive_func") - _endpoint: str - _client: Client + __slots__ = ("_registry", "_name", "_primitive_func") + _registry: str _name: str _primitive_function: PrimitiveFunctionType def __init__( self, - endpoint: str, - client: Client, + registry: Registry, name: str, primitive_func: PrimitiveFunctionType, ): - self._endpoint = endpoint - self._client = client + self._registry = registry.name self._name = name self._primitive_func = primitive_func @property def endpoint(self) -> str: - return self._endpoint - - @endpoint.setter - def endpoint(self, value: str): - self._endpoint = value + return self.registry.endpoint @property def name(self) -> str: return self._name + @property + def registry(self) -> Registry: + return lookup_registry(self._registry) + async def _primitive_call(self, input: Input) -> Output: return await self._primitive_func(input) async def _primitive_dispatch(self, input: Any = None) -> DispatchID: - [dispatch_id] = await self._client.dispatch([self._build_primitive_call(input)]) + [dispatch_id] = await self.registry.client.dispatch( + [self._build_primitive_call(input)] + ) return dispatch_id def _build_primitive_call( @@ -110,10 +112,6 @@ def _build_primitive_call( ) -P = ParamSpec("P") -T = TypeVar("T") - - class Function(PrimitiveFunction, Generic[P, T]): """Callable wrapper around a function meant to be used throughout the Dispatch Python SDK. @@ -123,12 +121,11 @@ class Function(PrimitiveFunction, Generic[P, T]): def __init__( self, - endpoint: str, - client: Client, + registry: Registry, name: str, primitive_func: PrimitiveFunctionType, ): - PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func) + PrimitiveFunction.__init__(self, registry, name, primitive_func) self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable( self._call_async ) @@ -144,9 +141,9 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: call = self.build_call(*args, **kwargs) - [dispatch_id] = await self._client.dispatch([call]) + [dispatch_id] = await self.registry.client.dispatch([call]) - return await self._client.wait(dispatch_id) + return await self.registry.client.wait(dispatch_id) def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: """Dispatch an asynchronous call to the function without @@ -189,29 +186,20 @@ def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs): class Registry: """Registry of functions.""" - __slots__ = ("functions", "endpoint", "client") + __slots__ = ("functions", "client", "_name", "_endpoint") def __init__( - self, - endpoint: Optional[str] = None, - api_key: Optional[str] = None, - api_url: Optional[str] = None, + self, name: str, client: Optional[Client] = None, endpoint: Optional[str] = None ): """Initialize a function registry. Args: - endpoint: URL of the endpoint that the function is accessible from. - Uses the value of the DISPATCH_ENDPOINT_URL environment variable - by default. + name: A unique name for the registry. - api_key: Dispatch API key to use for authentication when - dispatching calls to functions. Uses the value of the - DISPATCH_API_KEY environment variable by default. + endpoint: URL of the endpoint that the function is accessible from. - api_url: The URL of the Dispatch API to use when dispatching calls - to functions. Uses the value of the DISPATCH_API_URL environment - variable if set, otherwise defaults to the public Dispatch API - (DEFAULT_API_URL). + client: Client instance to use for dispatching calls to registered + functions. Defaults to creating a new client instance. Raises: ValueError: If any of the required arguments are missing. @@ -224,15 +212,45 @@ def __init__( raise ValueError( "missing application endpoint: set it with the DISPATCH_ENDPOINT_URL environment variable" ) - parsed_url = urlparse(endpoint) - if not parsed_url.netloc or not parsed_url.scheme: - raise ValueError( - f"{endpoint_from} must be a full URL with protocol and domain (e.g., https://example.com)" - ) logger.info("configuring Dispatch endpoint %s", endpoint) self.functions: Dict[str, PrimitiveFunction] = {} + self.client = client or Client() self.endpoint = endpoint - self.client = Client(api_key=api_key, api_url=api_url) + + if not name: + raise ValueError("missing registry name") + if name in _registries: + raise ValueError(f"registry with name '{name}' already exists") + self._name = name + _registries[name] = self + + def close(self): + """Closes the registry, removing it and all its functions from the + dispatch application.""" + name = self._name + if name: + self._name = "" + del _registries[name] + # TODO: remove registered functions + + @property + def name(self) -> str: + return self._name + + @property + def endpoint(self) -> str: + return self._endpoint + + @endpoint.setter + def endpoint(self, value: str): + parsed = urlparse(value) + if parsed.scheme not in ("http", "https"): + raise ValueError( + f"missing protocol scheme in registry endpoint URL: {value}" + ) + if not parsed.hostname: + raise ValueError(f"missing host in registry endpoint URL: {value}") + self._endpoint = value @overload def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... @@ -276,8 +294,7 @@ async def primitive_func(input: Input) -> Output: durable_primitive_func = durable(primitive_func) wrapped_func = Function[P, T]( - self.endpoint, - self.client, + self, name, durable_primitive_func, ) @@ -290,12 +307,7 @@ def primitive_function( """Decorator that registers primitive functions.""" name = primitive_func.__qualname__ logger.info("registering primitive function: %s", name) - wrapped_func = PrimitiveFunction( - self.endpoint, - self.client, - name, - primitive_func, - ) + wrapped_func = PrimitiveFunction(self, name, primitive_func) self._register(name, wrapped_func) return wrapped_func @@ -310,16 +322,50 @@ def batch(self): # -> Batch: # return self.client.batch() raise NotImplemented - def set_client(self, client: Client): - """Set the Client instance used to dispatch calls to registered functions.""" - # TODO: figure out a way to remove this method, it's only used in examples - self.client = client - for fn in self.functions.values(): - fn._client = client - def override_endpoint(self, endpoint: str): - for fn in self.functions.values(): - fn.endpoint = endpoint +_registries: Dict[str, Registry] = {} + +DEFAULT_REGISTRY_NAME: str = "default" +DEFAULT_REGISTRY: Optional[Registry] = None +"""The default registry for dispatch functions, used by dispatch applications +when no custom registry is provided. + +In most cases, applications do not need to create a custom registry, so this +one would be used by default. + +The default registry use DISPATCH_* environment variables for configuration, +or is uninitialized if they are not set. +""" + + +def default_registry() -> Registry: + """Returns the default registry for dispatch functions. + + The function initializes the default registry if it has not been initialized + yet, using the DISPATCH_* environment variables for configuration. + + Returns: + Registry: The default registry. + + Raises: + ValueError: If the DISPATCH_API_KEY or DISPATCH_ENDPOINT_URL environment + variables are missing. + """ + global DEFAULT_REGISTRY + if DEFAULT_REGISTRY is None: + DEFAULT_REGISTRY = Registry(DEFAULT_REGISTRY_NAME) + return DEFAULT_REGISTRY + + +def lookup_registry(name: str) -> Registry: + return default_registry() if name == DEFAULT_REGISTRY_NAME else _registries[name] + + +def set_default_registry(reg: Registry): + global DEFAULT_REGISTRY + global DEFAULT_REGISTRY_NAME + DEFAULT_REGISTRY = reg + DEFAULT_REGISTRY_NAME = reg.name class Client: diff --git a/src/dispatch/http.py b/src/dispatch/http.py index ccbcfe94..edbc9ca4 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -5,12 +5,25 @@ import os from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Iterable, List, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Coroutine, + Iterable, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, + overload, +) from aiohttp import web from http_message_signatures import InvalidSignature +from typing_extensions import ParamSpec, TypeAlias -from dispatch.function import Registry +from dispatch.function import Batch, Function, Registry, default_registry from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -24,6 +37,60 @@ logger = logging.getLogger(__name__) +P = ParamSpec("P") +T = TypeVar("T") + + +class FunctionService: + """FunctionService is an abstract class intended to be inherited by objects + that integrate dispatch with other server application frameworks. + + An application encapsulates a function Registry, and implements the API + common to all dispatch integrations. + """ + + def __init__( + self, + registry: Optional[Registry] = None, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + ): + self._registry = registry + self._verification_key = parse_verification_key( + verification_key, + endpoint=self.registry.endpoint, + ) + + @property + def registry(self) -> Registry: + return self._registry or default_registry() + + @property + def verification_key(self) -> Optional[Ed25519PublicKey]: + return self._verification_key + + @overload + def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... + + @overload + def function(self, func: Callable[P, T]) -> Function[P, T]: ... + + def function(self, func): + """Decorator that registers functions.""" + return self.registry.function(func) + + def batch(self) -> Batch: + return self.registry.batch() + + async def run(self, url, method, headers, data): + return await function_service_run( + url, + method, + headers, + data, + self.registry, + self.verification_key, + ) + class FunctionServiceError(Exception): __slots__ = ("status", "code", "message") @@ -44,7 +111,7 @@ def validate_content_length(content_length: int) -> Tuple[bool, str]: return True, "" -class FunctionService(BaseHTTPRequestHandler): +class FunctionServiceHTTPRequestHandler(BaseHTTPRequestHandler): def __init__( self, @@ -148,7 +215,7 @@ def __init__( ) def __call__(self, request, client_address, server): - return FunctionService( + return FunctionServiceHTTPRequestHandler( request, client_address, server, diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 07c306cc..d2cbfd5f 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -13,9 +13,15 @@ import dispatch.experimental.durable.registry from dispatch.function import Client as BaseClient -from dispatch.function import ClientError, Input, Output -from dispatch.function import Registry as BaseRegistry -from dispatch.http import Dispatch +from dispatch.function import ( + ClientError, + Input, + Output, + Registry, + default_registry, + set_default_registry, +) +from dispatch.http import Dispatch, FunctionService from dispatch.http import Server as BaseServer from dispatch.sdk.v1.call_pb2 import Call, CallResult from dispatch.sdk.v1.dispatch_pb2 import DispatchRequest, DispatchResponse @@ -46,7 +52,6 @@ ] P = ParamSpec("P") -R = TypeVar("R", bound=BaseRegistry) T = TypeVar("T") DISPATCH_ENDPOINT_URL = "http://127.0.0.1:0" @@ -62,17 +67,6 @@ def session(self) -> aiohttp.ClientSession: return aiohttp.ClientSession() -class Registry(BaseRegistry): - def __init__(self): - # placeholder values to initialize the base class prior to binding - # random ports. - super().__init__( - endpoint=DISPATCH_ENDPOINT_URL, - api_url=DISPATCH_API_URL, - api_key=DISPATCH_API_KEY, - ) - - class Server(BaseServer): def __init__(self, app: web.Application): super().__init__("127.0.0.1", 0, app) @@ -258,7 +252,8 @@ def session(self) -> aiohttp.ClientSession: return self._session -async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: +async def main(coro: Coroutine[Any, Any, None]) -> None: + reg = default_registry() api = Service() app = Dispatch(reg) try: @@ -268,7 +263,7 @@ async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: # ideal but it works for now. reg.client.api_url.value = backend.url reg.endpoint = server.url - await fn(reg) + await coro finally: await api.close() # TODO: let's figure out how to get rid of this global registry @@ -276,8 +271,8 @@ async def main(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: dispatch.experimental.durable.registry.clear_functions() -def run(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: - return asyncio.run(main(reg, fn)) +def run(coro: Coroutine[Any, Any, None]) -> None: + return asyncio.run(main(coro)) # TODO: these decorators still need work, until we figure out serialization @@ -297,20 +292,18 @@ def run(reg: R, fn: Callable[[R], Coroutine[Any, Any, None]]) -> None: # (WIP) -def function(fn: Callable[[Registry], Coroutine[Any, Any, None]]) -> Callable[[], None]: +def function(fn: Callable[[], Coroutine[Any, Any, None]]) -> Callable[[], None]: @wraps(fn) def wrapper(): - return run(Registry(), fn) + return run(fn()) return wrapper -def method( - fn: Callable[[T, Registry], Coroutine[Any, Any, None]] -) -> Callable[[T], None]: +def method(fn: Callable[[T], Coroutine[Any, Any, None]]) -> Callable[[T], None]: @wraps(fn) def wrapper(self: T): - return run(Registry(), lambda reg: fn(self, reg)) + return run(fn(self)) return wrapper @@ -332,6 +325,41 @@ def test(self): return test +_registry = Registry( + name=__name__, + endpoint=DISPATCH_ENDPOINT_URL, + client=Client(api_key=DISPATCH_API_KEY, api_url=DISPATCH_API_URL), +) + + +@_registry.function +def greet() -> str: + return "Hello World!" + + +@_registry.function +def greet_name(name: str) -> str: + return f"Hello world: {name}" + + +@_registry.function +def echo(name: str) -> str: + return name + + +@_registry.function +def length(name: str) -> int: + return len(name) + + +@_registry.function +def broken() -> str: + raise ValueError("something went wrong!") + + +set_default_registry(_registry) + + class TestCase(unittest.TestCase): """TestCase implements the generic test suite used in dispatch-py to test various integrations of the SDK with frameworks like FastAPI, Flask, etc... @@ -345,11 +373,7 @@ class TestCase(unittest.TestCase): more details). """ - def dispatch_test_init(self, api_key: str, api_url: str) -> BaseRegistry: - """Called to initialize each test case. The method returns the dispatch - function registry which can be used to register function instances - during tests. - """ + def dispatch_test_init(self, reg: Registry) -> str: raise NotImplementedError def dispatch_test_run(self): @@ -360,18 +384,14 @@ def dispatch_test_stop(self): def setUp(self): self.service = Service() + self.server = Server(self.service) self.server_loop = asyncio.new_event_loop() self.server_loop.run_until_complete(self.server.start()) - self.dispatch = self.dispatch_test_init( - api_key=DISPATCH_API_KEY, api_url=self.server.url - ) - self.dispatch.client = Client( - api_key=self.dispatch.client.api_key.value, - api_url=self.dispatch.client.api_url.value, - ) - + _registry.client.api_key.value = DISPATCH_API_KEY + _registry.client.api_url.value = self.server.url + _registry.endpoint = self.dispatch_test_init(_registry) self.client_thread = threading.Thread(target=self.dispatch_test_run) self.client_thread.start() @@ -385,20 +405,16 @@ def tearDown(self): # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially. - dispatch.experimental.durable.registry.clear_functions() + # + # We can't erase the registry because user tests might have registered + # functions in the global scope that would be lost after the first test + # we run. + # + # dispatch.experimental.durable.registry.clear_functions() @property def function_service_run_url(self) -> str: - return f"{self.dispatch.endpoint}/dispatch.sdk.v1.FunctionService/Run" - - def test_register_duplicate_functions(self): - @self.dispatch.function - def my_function(): ... - - with self.assertRaises(ValueError): - - @self.dispatch.function - def my_function(): ... + return f"{_registry.endpoint}/dispatch.sdk.v1.FunctionService/Run" @aiotest async def test_content_length_missing(self): @@ -442,41 +458,21 @@ async def test_call_function_missing(self): @aiotest async def test_call_function_no_input(self): - @self.dispatch.function - def my_function() -> str: - return "Hello World!" - - ret = await my_function() + ret = await greet() self.assertEqual(ret, "Hello World!") @aiotest async def test_call_function_with_input(self): - @self.dispatch.function - def my_function(name: str) -> str: - return f"Hello world: {name}" - - ret = await my_function("52") + ret = await greet_name("52") self.assertEqual(ret, "Hello world: 52") @aiotest async def test_call_function_raise_error(self): - @self.dispatch.function - def my_function(name: str) -> str: - raise ValueError("something went wrong!") - with self.assertRaises(ValueError) as e: - await my_function("52") + await broken() @aiotest async def test_call_two_functions(self): - @self.dispatch.function - def echo(name: str) -> str: - return name - - @self.dispatch.function - def length(name: str) -> int: - return len(name) - self.assertEqual(await echo("hello"), "hello") self.assertEqual(await length("hello"), 5) diff --git a/tests/dispatch/test_function.py b/tests/dispatch/test_function.py index 3550b4b5..276a458e 100644 --- a/tests/dispatch/test_function.py +++ b/tests/dispatch/test_function.py @@ -1,10 +1,18 @@ import pickle -from dispatch.test import Registry +from dispatch.function import Client, Registry +from dispatch.test import DISPATCH_API_KEY, DISPATCH_API_URL, DISPATCH_ENDPOINT_URL def test_serializable(): - reg = Registry() + reg = Registry( + name=__name__, + endpoint=DISPATCH_ENDPOINT_URL, + client=Client( + api_key=DISPATCH_API_KEY, + api_url=DISPATCH_API_URL, + ), + ) @reg.function def my_function(): @@ -12,3 +20,4 @@ def my_function(): s = pickle.dumps(my_function) pickle.loads(s) + reg.close() From 6bfb086bb5696b22d710e1cb3a259ddccdba9fcd Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 14:11:54 -0700 Subject: [PATCH 09/34] remove unused variables --- src/dispatch/__init__.py | 1 - src/dispatch/experimental/lambda_handler.py | 2 +- src/dispatch/fastapi.py | 1 - src/dispatch/flask.py | 1 - src/dispatch/function.py | 3 --- src/dispatch/http.py | 3 --- src/dispatch/proto.py | 13 +++++++++++-- src/dispatch/test/__init__.py | 2 +- src/dispatch/test/client.py | 2 +- 9 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 788c6a37..bd50a272 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -from concurrent import futures from http.server import ThreadingHTTPServer from typing import Any, Callable, Coroutine, Optional, TypeVar, overload from urllib.parse import urlsplit diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 2b098052..8990c6a1 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -22,7 +22,7 @@ def handler(event, context): import base64 import json import logging -from typing import Optional, Union +from typing import Optional from awslambdaric.lambda_context import LambdaContext diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 4b0ff9c0..bacdbdef 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -17,7 +17,6 @@ def read_root(): my_function.dispatch() """ -import asyncio import logging from typing import Optional, Union diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index 7991f20a..1fb05d3e 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -19,7 +19,6 @@ def read_root(): import asyncio import logging -import threading # from queue import Queue from typing import Optional, Union diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 8791297c..5818aa79 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -6,7 +6,6 @@ import logging import os from functools import wraps -from types import CoroutineType from typing import ( Any, Awaitable, @@ -204,10 +203,8 @@ def __init__( Raises: ValueError: If any of the required arguments are missing. """ - endpoint_from = "endpoint argument" if not endpoint: endpoint = os.getenv("DISPATCH_ENDPOINT_URL") - endpoint_from = "DISPATCH_ENDPOINT_URL" if not endpoint: raise ValueError( "missing application endpoint: set it with the DISPATCH_ENDPOINT_URL environment variable" diff --git a/src/dispatch/http.py b/src/dispatch/http.py index edbc9ca4..14475ce1 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -2,15 +2,12 @@ import asyncio import logging -import os from datetime import timedelta from http.server import BaseHTTPRequestHandler from typing import ( Any, Callable, Coroutine, - Iterable, - List, Mapping, Optional, Tuple, diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index ffe6a10c..9576892b 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -3,7 +3,6 @@ import pickle from dataclasses import dataclass from traceback import format_exception -from types import TracebackType from typing import Any, Dict, List, Optional, Tuple import google.protobuf.any_pb2 @@ -79,7 +78,17 @@ def __init__(self, req: function_pb.RunRequest): self._has_input = req.HasField("input") if self._has_input: - self._input = _pb_any_unpack(req.input) + if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): + input_pb = google.protobuf.wrappers_pb2.BytesValue() + req.input.Unpack(input_pb) + input_bytes = input_pb.value + try: + self._input = pickle.loads(input_bytes) + except Exception: + self._input = input_bytes + else: + self._input = _pb_any_unpack(req.input) + else: if req.poll_result.coroutine_state: raise IncompatibleStateError # coroutine_state is deprecated diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index d2cbfd5f..9b21831d 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -4,7 +4,7 @@ import unittest from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, overload +from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar import aiohttp from aiohttp import web diff --git a/src/dispatch/test/client.py b/src/dispatch/test/client.py index b9f56ab7..6ff3ba88 100644 --- a/src/dispatch/test/client.py +++ b/src/dispatch/test/client.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Mapping, Optional, Protocol, Union +from typing import Optional import grpc From ecb9d337bdce02ef12b20fe18083fd6c63396fbd Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 14:12:20 -0700 Subject: [PATCH 10/34] enable nested function call test Signed-off-by: Achille Roussel --- src/dispatch/test/__init__.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 9b21831d..ef11729e 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -347,6 +347,11 @@ def echo(name: str) -> str: return name +@_registry.function +async def echo2(name: str) -> str: + return await echo(name) + + @_registry.function def length(name: str) -> int: return len(name) @@ -488,17 +493,9 @@ async def test_call_two_functions(self): # where we register each function once in the globla scope, so no cells need # to be created. - # @aiotest - # async def test_call_nested_function_with_result(self): - # @self.dispatch.function - # def echo(name: str) -> str: - # return name - - # @self.dispatch.function - # async def echo2(name: str) -> str: - # return await echo(name) - - # self.assertEqual(await echo2("hello"), "hello") + @aiotest + async def test_call_nested_function_with_result(self): + self.assertEqual(await echo2("hello"), "hello") # @aiotest # async def test_call_nested_function_with_error(self): From d9e73d4ed120c54fe94345562dff7445b62db3b9 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 14:16:06 -0700 Subject: [PATCH 11/34] fix formatting check for registry endpoint Signed-off-by: Achille Roussel --- src/dispatch/function.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 5818aa79..8eb960cf 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -241,10 +241,14 @@ def endpoint(self) -> str: @endpoint.setter def endpoint(self, value: str): parsed = urlparse(value) - if parsed.scheme not in ("http", "https"): + if not parsed.scheme: raise ValueError( f"missing protocol scheme in registry endpoint URL: {value}" ) + if parsed.scheme not in ("bridge", "http", "https"): + raise ValueError( + f"invalid protocol scheme in registry endpoint URL: {value}" + ) if not parsed.hostname: raise ValueError(f"missing host in registry endpoint URL: {value}") self._endpoint = value From b894337c0d4a1195bb46e1c087442a76d9818ca7 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 14:39:10 -0700 Subject: [PATCH 12/34] fix PrimitiveFunction.__call__, the method cannot be async Signed-off-by: Achille Roussel --- src/dispatch/function.py | 29 +++++++++++++++++++-------- src/dispatch/test/__init__.py | 37 +++++++++++------------------------ 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 8eb960cf..404225d5 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -132,17 +132,30 @@ def __init__( async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T: return await dispatch.coroutine.call(self.build_call(*args, **kwargs)) - async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + async def _call_dispatch(self, *args: P.args, **kwargs: P.kwargs) -> T: + call = self.build_call(*args, **kwargs) + client = self.registry.client + [dispatch_id] = await client.dispatch([call]) + return await client.wait(dispatch_id) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: """Call the function asynchronously (through Dispatch), and return a coroutine that can be awaited to retrieve the call result.""" + # Note: this method cannot be made `async`, otherwise Python creates + # ont additional wrapping layer of native coroutine that cannot be + # pickled and breaks serialization. + # + # The durable coroutine returned by calling _func_indirect must be + # returned as is. + # + # For cases where this method is called outside the context of a + # dispatch function, it still returns a native coroutine object, + # but that doesn't matter since there is no state serialization in + # that case. if in_function_call(): - return await self._func_indirect(*args, **kwargs) - - call = self.build_call(*args, **kwargs) - - [dispatch_id] = await self.registry.client.dispatch([call]) - - return await self.registry.client.wait(dispatch_id) + return self._func_indirect(*args, **kwargs) + else: + return self._call_dispatch(*args, **kwargs) def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: """Dispatch an asynchronous call to the function without diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index ef11729e..27c0165f 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -348,7 +348,7 @@ def echo(name: str) -> str: @_registry.function -async def echo2(name: str) -> str: +async def echo_nested(name: str) -> str: return await echo(name) @@ -362,6 +362,11 @@ def broken() -> str: raise ValueError("something went wrong!") +@_registry.function +async def broken_nested(name: str) -> str: + return await broken() + + set_default_registry(_registry) @@ -481,34 +486,14 @@ async def test_call_two_functions(self): self.assertEqual(await echo("hello"), "hello") self.assertEqual(await length("hello"), 5) - # TODO: - # - # The declaration of nested functions in these tests causes CPython to - # generate cell objects since the local variables are referenced by multiple - # scopes. - # - # Maybe time to revisit https://github.com/dispatchrun/dispatch-py/pull/121 - # - # Alternatively, we could rewrite the test suite to use a global registry - # where we register each function once in the globla scope, so no cells need - # to be created. - @aiotest async def test_call_nested_function_with_result(self): - self.assertEqual(await echo2("hello"), "hello") + self.assertEqual(await echo_nested("hello"), "hello") - # @aiotest - # async def test_call_nested_function_with_error(self): - # @self.dispatch.function - # def broken_function(name: str) -> str: - # raise ValueError("something went wrong!") - - # @self.dispatch.function - # async def working_function(name: str) -> str: - # return await broken_function(name) - - # with self.assertRaises(ValueError) as e: - # await working_function("hello") + @aiotest + async def test_call_nested_function_with_error(self): + with self.assertRaises(ValueError) as e: + await broken_nested("hello") class ClientRequestContentLengthMissing(aiohttp.ClientRequest): From b4e2db4a5788514c1d1408489533ffdd2c1f7ab1 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 16:41:10 -0700 Subject: [PATCH 13/34] more fixes after rebase Signed-off-by: Achille Roussel --- src/dispatch/proto.py | 12 +-- src/dispatch/test/__init__.py | 1 + tests/dispatch/test_scheduler.py | 20 +---- tests/test_aiohttp.py | 124 ------------------------------- 4 files changed, 3 insertions(+), 154 deletions(-) delete mode 100644 tests/test_aiohttp.py diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 9576892b..9af3631d 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -78,17 +78,7 @@ def __init__(self, req: function_pb.RunRequest): self._has_input = req.HasField("input") if self._has_input: - if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): - input_pb = google.protobuf.wrappers_pb2.BytesValue() - req.input.Unpack(input_pb) - input_bytes = input_pb.value - try: - self._input = pickle.loads(input_bytes) - except Exception: - self._input = input_bytes - else: - self._input = _pb_any_unpack(req.input) - + self._input = _pb_any_unpack(req.input) else: if req.poll_result.coroutine_state: raise IncompatibleStateError # coroutine_state is deprecated diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 27c0165f..e6de18f4 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -231,6 +231,7 @@ def make_request(call: Call) -> RunRequest: root_dispatch_id=root_dispatch_id, poll_result=PollResult( coroutine_state=res.poll.coroutine_state, + typed_coroutine_state=res.poll.typed_coroutine_state, results=results, ), ) diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index e82d0db4..c15ed848 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -464,7 +464,7 @@ async def resume( poll = assert_poll(prev_output) input = Input.from_poll_results( main.__qualname__, - poll.coroutine_state, + any_unpickle(poll.typed_coroutine_state), call_results, Error.from_exception(poll_error) if poll_error else None, ) @@ -492,30 +492,12 @@ def assert_exit_result_value(output: Output, expect: Any): assert expect == any_unpickle(result.output) -<<<<<<< HEAD - def resume( - self, - main: Callable, - prev_output: Output, - call_results: List[CallResult], - poll_error: Optional[Exception] = None, - ): - poll = self.assert_poll(prev_output) - input = Input.from_poll_results( - main.__qualname__, - any_unpickle(poll.typed_coroutine_state), - call_results, - Error.from_exception(poll_error) if poll_error else None, - ) - return self.runner.run(OneShotScheduler(main).run(input)) -======= def assert_exit_result_error( output: Output, expect: Type[Exception], message: Optional[str] = None ): result = assert_exit_result(output) assert not result.HasField("output") assert result.HasField("error") ->>>>>>> 626d02d (aiohttp: refactor internals to use asyncio throughout the SDK) error = Error._from_proto(result.error).to_exception() assert error.__class__ == expect diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py deleted file mode 100644 index 1f535436..00000000 --- a/tests/test_aiohttp.py +++ /dev/null @@ -1,124 +0,0 @@ -import asyncio -import base64 -import os -import pickle -import struct -import threading -import unittest -from typing import Any, Tuple -from unittest import mock - -import fastapi -import google.protobuf.any_pb2 -import google.protobuf.wrappers_pb2 -import httpx -from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey - -import dispatch.test.httpx -from dispatch.aiohttp import Dispatch, Server -from dispatch.asyncio import Runner -from dispatch.experimental.durable.registry import clear_functions -from dispatch.function import Arguments, Error, Function, Input, Output, Registry -from dispatch.proto import _any_unpickle as any_unpickle -from dispatch.proto import _pb_any_pickle as any_pickle -from dispatch.sdk.v1 import call_pb2 as call_pb -from dispatch.sdk.v1 import function_pb2 as function_pb -from dispatch.signature import parse_verification_key, public_key_from_pem -from dispatch.status import Status -from dispatch.test import EndpointClient - -public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----" -public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----" -public_key = public_key_from_pem(public_key_pem) -public_key_bytes = public_key.public_bytes_raw() -public_key_b64 = base64.b64encode(public_key_bytes) - -from datetime import datetime - - -def run(runner: Runner, server: Server, ready: threading.Event): - try: - with runner: - runner.run(serve(server, ready)) - except RuntimeError as e: - pass # silence errors triggered by stopping the loop after tests are done - - -async def serve(server: Server, ready: threading.Event): - async with server: - ready.set() # allow the test to continue after the server started - await asyncio.Event().wait() - - -class TestAIOHTTP(unittest.TestCase): - def setUp(self): - ready = threading.Event() - self.runner = Runner() - - host = "127.0.0.1" - port = 9997 - - self.endpoint = f"http://{host}:{port}" - self.dispatch = Dispatch( - Registry( - endpoint=self.endpoint, - api_key="0000000000000000", - api_url="http://127.0.0.1:10000", - ), - ) - - self.client = httpx.Client(timeout=1.0) - self.server = Server(host, port, self.dispatch) - self.thread = threading.Thread( - target=lambda: run(self.runner, self.server, ready) - ) - self.thread.start() - ready.wait() - - def tearDown(self): - loop = self.runner.get_loop() - loop.call_soon_threadsafe(loop.stop) - self.thread.join(timeout=1.0) - self.client.close() - - def test_content_length_missing(self): - resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run") - body = resp.read() - self.assertEqual(resp.status_code, 400) - self.assertEqual( - body, b'{"code":"invalid_argument","message":"content length is required"}' - ) - - def test_content_length_too_large(self): - resp = self.client.post( - f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run", - data={"msg": "a" * 16_000_001}, - ) - body = resp.read() - self.assertEqual(resp.status_code, 400) - self.assertEqual( - body, b'{"code":"invalid_argument","message":"content length is too large"}' - ) - - def test_simple_request(self): - @self.dispatch.registry.primitive_function - async def my_function(input: Input) -> Output: - return Output.value( - f"You told me: '{input.input}' ({len(input.input)} characters)" - ) - - http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) - client = EndpointClient(http_client) - - req = function_pb.RunRequest( - function=my_function.name, - input=any_pickle("Hello World!"), - ) - - resp = client.run(req) - - self.assertIsInstance(resp, function_pb.RunResponse) - - output = any_unpickle(resp.exit.result.output) - - self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") From 3d0b0bbf8a4762681cae7b15fa1af159e8d0cd22 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 17:21:48 -0700 Subject: [PATCH 14/34] cleanup Signed-off-by: Achille Roussel --- src/dispatch/fastapi.py | 4 +- src/dispatch/flask.py | 74 +++-------------------------------- src/dispatch/test/__init__.py | 28 +------------ 3 files changed, 9 insertions(+), 97 deletions(-) diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index bacdbdef..3abf7b1a 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -6,7 +6,7 @@ from dispatch.fastapi import Dispatch app = fastapi.FastAPI() - dispatch = Dispatch(app, api_key="test-key") + dispatch = Dispatch(app) @dispatch.function def my_function(): @@ -80,7 +80,7 @@ async def on_error(request: fastapi.Request, exc: FunctionServiceError): # like the official gRPC server does. "/Run", ) - async def execute(request: fastapi.Request): + async def run(request: fastapi.Request): valid, reason = validate_content_length( int(request.headers.get("content-length", 0)) ) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index 1fb05d3e..5df37ac7 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -6,7 +6,7 @@ from dispatch.flask import Dispatch app = Flask(__name__) - dispatch = Dispatch(app, api_key="test-key") + dispatch = Dispatch(app) @dispatch.function def my_function(): @@ -65,51 +65,14 @@ def __init__( raise ValueError( "missing Flask app as first argument of the Dispatch constructor" ) - super().__init__(registry, verification_key) - app.errorhandler(FunctionServiceError)(self._handle_error) - app.post("/dispatch.sdk.v1.FunctionService/Run")(self._execute) - - # TODO: earlier experiment I ran because it seemed like tasks created - # by the /Dispatch endpoint were canceled when calls to /Wait were made. - # - # After further investigation, it might have been caused by a bug when - # setting the thread local state indicating that we are being invoked - # from a scheduler thread, which resulted in unnecessary dispatch calls. - # - # I'm keeping the code around for now in case it ends up being needed in - # the short term. Feel free to remove if you run into this comment and - # it's no longer relevant. - # --- - # Here we have to use one event loop for the whole application to allow - # tasks spawned by request handlers to persist after the request is done. - # - # This is essential for tests to pass when using the /Dispatch and /Wait - # endpoints to wait on function results. - # self._loop = asyncio.new_event_loop() - # self._thread = threading.Thread(target=self._run_event_loop) - # self._thread.start() - - # def close(self): - # self._loop.call_soon_threadsafe(self._loop.stop) - # self._thread.join() - - # def __enter__(self): - # return self - - # def __exit__(self, exc_type, exc_value, traceback): - # self.close() - - # def _run_event_loop(self): - # asyncio.set_event_loop(self._loop) - # self._loop.run_forever() - # self._loop.run_until_complete(self._loop.shutdown_asyncgens()) - # self._loop.close() - - def _handle_error(self, exc: FunctionServiceError): + app.errorhandler(FunctionServiceError)(self._on_error) + app.post("/dispatch.sdk.v1.FunctionService/Run")(self._run) + + def _on_error(self, exc: FunctionServiceError): return {"code": exc.code, "message": exc.message}, exc.status - def _execute(self): + def _run(self): valid, reason = validate_content_length(request.content_length or 0) if not valid: return {"code": "invalid_argument", "message": reason}, 400 @@ -123,31 +86,6 @@ def _execute(self): ) ) - # queue = Queue[asyncio.Task](maxsize=1) - # - # url, method, headers = request.url, request.method, dict(request.headers) - # def execute_task(): - # task = self._loop.create_task( - # function_service_run( - # url, - # method, - # headers, - # data, - # self, - # self._verification_key, - # ) - # ) - # task.add_done_callback(queue.put) - - # self._loop.call_soon_threadsafe(execute_task) - # task: asyncio.Task = queue.get() - - # exception = task.exception() - # if exception is not None: - # raise exception - - # content: bytes = task.result() - res = make_response(content) res.content_type = "application/proto" return res diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index e6de18f4..ad586842 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -269,30 +269,13 @@ async def main(coro: Coroutine[Any, Any, None]) -> None: await api.close() # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially. - dispatch.experimental.durable.registry.clear_functions() + # dispatch.experimental.durable.registry.clear_functions() def run(coro: Coroutine[Any, Any, None]) -> None: return asyncio.run(main(coro)) -# TODO: these decorators still need work, until we figure out serialization -# for cell objects, they are not very useful since the registry they receive -# as argument cannot be used to register dispatch functions. -# -# The simplest would be to use a global registry for external application tests, -# maybe we can figure out a way to make this easy with a syntax like: -# -# import main -# import dispatch.test -# -# @dispatch.test.function(main.dispatch) -# async def test_something(): -# ... -# -# (WIP) - - def function(fn: Callable[[], Coroutine[Any, Any, None]]) -> Callable[[], None]: @wraps(fn) def wrapper(): @@ -414,15 +397,6 @@ def tearDown(self): self.server_loop.run_until_complete(self.server_loop.shutdown_asyncgens()) self.server_loop.close() - # TODO: let's figure out how to get rid of this global registry - # state at some point, which forces tests to be run sequentially. - # - # We can't erase the registry because user tests might have registered - # functions in the global scope that would be lost after the first test - # we run. - # - # dispatch.experimental.durable.registry.clear_functions() - @property def function_service_run_url(self) -> str: return f"{_registry.endpoint}/dispatch.sdk.v1.FunctionService/Run" From f0a54e2620d9e0151d4310db1d8e56e99a76a8c2 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 17:22:48 -0700 Subject: [PATCH 15/34] cleanup Signed-off-by: Achille Roussel --- src/dispatch/flask.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index 5df37ac7..ffd6c923 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -19,8 +19,6 @@ def read_root(): import asyncio import logging - -# from queue import Queue from typing import Optional, Union from flask import Flask, make_response, request From 60c231290cf0451cbc655d3edf7ea90e2d763302 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 12:29:23 -0700 Subject: [PATCH 16/34] try fixing compat with 3.8 Signed-off-by: Achille Roussel --- tests/test_fastapi.py | 6 +++++- tests/test_http.py | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 7eceb5ff..5cd621e5 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,5 +1,6 @@ import asyncio import socket +import sys from typing import Any, Optional import fastapi @@ -56,7 +57,10 @@ def dispatch_test_init(self, reg: Registry) -> str: self.sockets = [sock] self.uvicorn = uvicorn.Server(config) self.runner = Runner() - self.event = asyncio.Event() + if sys.version_info > (3, 8): + self.event = asyncio.Event() + else: + self.event = asyncio.Event(loop=self.runner.get_loop()) return f"http://{host}:{port}" def dispatch_test_run(self): diff --git a/tests/test_http.py b/tests/test_http.py index 34d27ab4..9271b229 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,5 +1,6 @@ import asyncio import socket +import sys from http.server import HTTPServer import dispatch.test @@ -41,11 +42,15 @@ def dispatch_test_init(self, reg: Registry) -> str: host = "127.0.0.1" port = 0 - self.aiowait = asyncio.Event() self.aioloop = Runner() self.aiohttp = Server(host, port, Dispatch(reg)) self.aioloop.run(self.aiohttp.start()) + if sys.version_info > (3, 8): + self.aiowait = asyncio.Event() + else: + self.aiowait = asyncio.Event(loop=self.aioloop.get_loop()) + return f"http://{self.aiohttp.host}:{self.aiohttp.port}" def dispatch_test_run(self): From c053298418009f552b2f52c1f7a4366c08a6d459 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 13:21:16 -0700 Subject: [PATCH 17/34] fix python version check Signed-off-by: Achille Roussel --- tests/test_fastapi.py | 2 +- tests/test_http.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 5cd621e5..9c067ef4 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -57,7 +57,7 @@ def dispatch_test_init(self, reg: Registry) -> str: self.sockets = [sock] self.uvicorn = uvicorn.Server(config) self.runner = Runner() - if sys.version_info > (3, 8): + if sys.version_info >= (3, 9): self.event = asyncio.Event() else: self.event = asyncio.Event(loop=self.runner.get_loop()) diff --git a/tests/test_http.py b/tests/test_http.py index 9271b229..69ad3654 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -46,7 +46,7 @@ def dispatch_test_init(self, reg: Registry) -> str: self.aiohttp = Server(host, port, Dispatch(reg)) self.aioloop.run(self.aiohttp.start()) - if sys.version_info > (3, 8): + if sys.version_info >= (3, 9): self.aiowait = asyncio.Event() else: self.aiowait = asyncio.Event(loop=self.aioloop.get_loop()) From d6fa55596a258885deb98b78e9110a55291defe4 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 14:48:49 -0700 Subject: [PATCH 18/34] emulate wait capability Signed-off-by: Achille Roussel --- examples/github_stats/app.py | 27 +++++----------- examples/github_stats/test_app.py | 52 ------------------------------- src/dispatch/__init__.py | 49 ++++++++++++++++++----------- src/dispatch/function.py | 46 +++++++++++++++++---------- src/dispatch/http.py | 15 +++++++-- src/dispatch/test/__init__.py | 33 ++++++++++---------- 6 files changed, 96 insertions(+), 126 deletions(-) delete mode 100644 examples/github_stats/test_app.py diff --git a/examples/github_stats/app.py b/examples/github_stats/app.py index 513bb24c..996743d1 100644 --- a/examples/github_stats/app.py +++ b/examples/github_stats/app.py @@ -14,16 +14,9 @@ """ +import dispatch import httpx -from fastapi import FastAPI - from dispatch.error import ThrottleError -from dispatch.fastapi import Dispatch - -app = FastAPI() - -dispatch = Dispatch(app) - def get_gh_api(url): print(f"GET {url}") @@ -36,21 +29,21 @@ def get_gh_api(url): @dispatch.function -async def get_repo_info(repo_owner: str, repo_name: str): +async def get_repo_info(repo_owner: str, repo_name: str) -> dict: url = f"https://api.github.com/repos/{repo_owner}/{repo_name}" repo_info = get_gh_api(url) return repo_info @dispatch.function -async def get_contributors(repo_info: dict): +async def get_contributors(repo_info: dict) -> list[dict]: url = repo_info["contributors_url"] contributors = get_gh_api(url) return contributors @dispatch.function -async def main(): +async def main() -> list[dict]: repo_info = await get_repo_info("dispatchrun", "coroutine") print( f"""Repository: {repo_info['full_name']} @@ -58,13 +51,9 @@ async def main(): Watchers: {repo_info['watchers_count']} Forks: {repo_info['forks_count']}""" ) - - contributors = await get_contributors(repo_info) - print(f"Contributors: {len(contributors)}") - return + return await get_contributors(repo_info) -@app.get("/") -def root(): - main.dispatch() - return "OK" +if __name__ == "__main__": + contributors = dispatch.run(main()) + print(f"Contributors: {len(contributors)}") diff --git a/examples/github_stats/test_app.py b/examples/github_stats/test_app.py deleted file mode 100644 index 37ca0d84..00000000 --- a/examples/github_stats/test_app.py +++ /dev/null @@ -1,52 +0,0 @@ -# This file is not part of the example. It is a test file to ensure the example -# works as expected during the CI process. - - -import os -import unittest -from unittest import mock - -from dispatch.function import Client -from dispatch.test import DispatchServer, DispatchService, EndpointClient -from dispatch.test.fastapi import http_client - - -class TestGithubStats(unittest.TestCase): - @mock.patch.dict( - os.environ, - { - "DISPATCH_ENDPOINT_URL": "http://function-service", - "DISPATCH_API_KEY": "0000000000000000", - }, - ) - def test_app(self): - from .app import app, dispatch - - # Setup a fake Dispatch server. - app_client = http_client(app) - endpoint_client = EndpointClient(app_client) - dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) - with DispatchServer(dispatch_service) as dispatch_server: - # Use it when dispatching function calls. - dispatch.registry.client = Client(api_url=dispatch_server.url) - - response = app_client.get("/") - self.assertEqual(response.status_code, 200) - - while dispatch_service.queue: - dispatch_service.dispatch_calls() - - # Three unique functions were called, with five total round-trips. - # The main function is called initially, and then polls - # twice, for three total round-trips. There's one round-trip - # to get_repo_info and one round-trip to get_contributors. - self.assertEqual( - 3, len(dispatch_service.roundtrips) - ) # 3 unique functions were called - self.assertEqual( - 5, - sum( - len(roundtrips) - for roundtrips in dispatch_service.roundtrips.values() - ), - ) diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index bd50a272..ea0090a2 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import os from http.server import ThreadingHTTPServer from typing import Any, Callable, Coroutine, Optional, TypeVar, overload @@ -20,7 +21,7 @@ Reset, default_registry, ) -from dispatch.http import Dispatch +from dispatch.http import Dispatch, Server from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output from dispatch.status import Status @@ -63,7 +64,21 @@ def function(func): return default_registry().function(func) -def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwargs): +async def main(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T: + address = addr or str(os.environ.get("DISPATCH_ENDPOINT_ADDR")) or "localhost:8000" + parsed_url = urlsplit("//" + address) + + host = parsed_url.hostname or "" + port = parsed_url.port or 0 + + reg = default_registry() + app = Dispatch(reg) + + async with Server(host, port, app) as server: + return await coro + + +def run(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T: """Run the default dispatch server. The default server uses a function registry where functions tagged by the `@dispatch.function` decorator are registered. @@ -73,27 +88,23 @@ def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwa to the Dispatch bridge API. Args: - init: An initialization function called after binding the server address - but before entering the event loop to handle requests. - - args: Positional arguments to pass to the entrypoint. + coro: The coroutine to run as the entrypoint, the function returns + when the coroutine returns. - kwargs: Keyword arguments to pass to the entrypoint. + addr: The address to bind the server to. If not provided, the server + will bind to the address specified by the `DISPATCH_ENDPOINT_ADDR` + environment variable. If the environment variable is not set, the + server will bind to `localhost:8000`. Returns: - The return value of the entrypoint function. + The value returned by the coroutine. """ - address = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000") - parsed_url = urlsplit("//" + address) - server_address = (parsed_url.hostname or "", parsed_url.port or 0) - server = ThreadingHTTPServer(server_address, Dispatch(default_registry())) - try: - if init is not None: - init(*args, **kwargs) - server.serve_forever() - finally: - server.shutdown() - server.server_close() + return asyncio.run(main(coro, addr)) + + +def run_forever(): + """Run the default dispatch server forever.""" + return run(asyncio.Event().wait()) def batch() -> Batch: diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 404225d5..98442d0e 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -382,6 +382,10 @@ def set_default_registry(reg: Registry): DEFAULT_REGISTRY_NAME = reg.name +# TODO: this is a temporary solution to track inflight tasks and allow waiting +# for results. +_calls: Dict[str, asyncio.Future] = {} + class Client: """Client for the Dispatch API.""" @@ -469,6 +473,11 @@ async def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]: resp = dispatch_pb.DispatchResponse() resp.ParseFromString(data) + # TODO: remove when we implemented the wait endpoint in the server + for dispatch_id in resp.dispatch_ids: + if dispatch_id not in _calls: + _calls[dispatch_id] = asyncio.Future() + dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids] if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -479,23 +488,26 @@ async def dispatch(self, calls: Iterable[Call]) -> List[DispatchID]: return dispatch_ids async def wait(self, dispatch_id: DispatchID) -> Any: - (url, headers, timeout) = self.request("/dispatch.sdk.v1.DispatchService/Wait") - data = dispatch_id.encode("utf-8") - - async with self.session() as session: - async with session.post( - url, headers=headers, data=data, timeout=timeout - ) as res: - data = await res.read() - self._check_response(res.status, data) - - resp = call_pb.CallResult() - resp.ParseFromString(data) - - result = CallResult._from_proto(resp) - if result.error is not None: - raise result.error.to_exception() - return result.output + # (url, headers, timeout) = self.request("/dispatch.sdk.v1.DispatchService/Wait") + # data = dispatch_id.encode("utf-8") + + # async with self.session() as session: + # async with session.post( + # url, headers=headers, data=data, timeout=timeout + # ) as res: + # data = await res.read() + # self._check_response(res.status, data) + + # resp = call_pb.CallResult() + # resp.ParseFromString(data) + + # result = CallResult._from_proto(resp) + # if result.error is not None: + # raise result.error.to_exception() + # return result.output + + future = _calls[dispatch_id] + return await future def _check_response(self, status: int, data: bytes): if status == 200: diff --git a/src/dispatch/http.py b/src/dispatch/http.py index 14475ce1..b9937d08 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -20,8 +20,8 @@ from http_message_signatures import InvalidSignature from typing_extensions import ParamSpec, TypeAlias -from dispatch.function import Batch, Function, Registry, default_registry -from dispatch.proto import Input +from dispatch.function import Batch, Function, Registry, default_registry, _calls +from dispatch.proto import CallResult, Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( CaseInsensitiveDict, @@ -78,7 +78,7 @@ def function(self, func): def batch(self) -> Batch: return self.registry.batch() - async def run(self, url, method, headers, data): + async def run(self, url: str, method: str, headers: Mapping[str, str], data: bytes) -> bytes: return await function_service_run( url, method, @@ -380,6 +380,9 @@ async def function_service_run( response = output._message status = Status(response.status) + if req.dispatch_id not in _calls: + _calls[req.dispatch_id] = asyncio.Future() + if response.HasField("poll"): logger.debug( "function '%s' polling with %d call(s)", @@ -392,6 +395,12 @@ async def function_service_run( logger.debug("function '%s' exiting with no result", req.function) else: result = exit.result + call_result = CallResult._from_proto(result) + call_future = _calls[req.dispatch_id] + if call_result.error is not None: + call_future.set_exception(call_result.error.to_exception()) + else: + call_future.set_result(call_result.output) if result.HasField("output"): logger.debug("function '%s' exiting with output value", req.function) elif result.HasField("error"): diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index ad586842..1b93ff47 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -58,6 +58,7 @@ DISPATCH_API_URL = "http://127.0.0.1:0" DISPATCH_API_KEY = "916CC3D280BB46DDBDA984B3DD10059A" +_dispatch_ids = (str(i) for i in range(2**32 - 1)) class Client(BaseClient): def session(self) -> aiohttp.ClientSession: @@ -75,14 +76,12 @@ def __init__(self, app: web.Application): def url(self): return f"http://{self.host}:{self.port}" - class Service(web.Application): tasks: Dict[str, asyncio.Task] _session: Optional[aiohttp.ClientSession] = None def __init__(self, session: Optional[aiohttp.ClientSession] = None): super().__init__() - self.dispatch_ids = (str(i) for i in range(2**32 - 1)) self.tasks = {} self.add_routes( [ @@ -126,7 +125,7 @@ async def handle_wait_request(self, request: web.Request): ) async def dispatch(self, req: DispatchRequest) -> DispatchResponse: - dispatch_ids = [next(self.dispatch_ids) for _ in req.calls] + dispatch_ids = [next(_dispatch_ids) for _ in req.calls] for call, dispatch_id in zip(req.calls, dispatch_ids): self.tasks[dispatch_id] = asyncio.create_task( @@ -208,19 +207,21 @@ def make_request(call: Call) -> RunRequest: ) # TODO: enforce poll limits - results = await asyncio.gather( - *[ - self.call( - call=subcall, - dispatch_id=subcall_dispatch_id, - parent_dispatch_id=dispatch_id, - root_dispatch_id=root_dispatch_id, - ) - for subcall, subcall_dispatch_id in zip( - res.poll.calls, next(self.dispatch_ids) - ) - ] - ) + subcall_dispatch_ids = [next(_dispatch_ids) for _ in res.poll.calls] + + subcalls = [ + self.call( + call=subcall, + dispatch_id=subcall_dispatch_id, + parent_dispatch_id=dispatch_id, + root_dispatch_id=root_dispatch_id, + ) + for subcall, subcall_dispatch_id in zip( + res.poll.calls, subcall_dispatch_ids + ) + ] + + results = await asyncio.gather(*subcalls) req = RunRequest( function=req.function, From ccc4ab8e63ad2f4000674fddfdc161aab5ab44f5 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 15:00:07 -0700 Subject: [PATCH 19/34] rewrite examples to use dispatch.run Signed-off-by: Achille Roussel --- examples/auto_retry.py | 23 +++++ examples/auto_retry/__init__.py | 0 examples/auto_retry/app.py | 64 -------------- examples/auto_retry/test_app.py | 51 ----------- examples/{fanout => }/fanout.py | 43 ++-------- examples/fanout/__init__.py | 0 examples/fanout/test_fanout.py | 19 ----- examples/getting_started.py | 16 ++++ examples/getting_started/__init__.py | 0 examples/getting_started/app.py | 85 ------------------- examples/getting_started/test_app.py | 40 --------- .../{github_stats/app.py => github_stats.py} | 0 examples/github_stats/__init__.py | 0 13 files changed, 47 insertions(+), 294 deletions(-) create mode 100644 examples/auto_retry.py delete mode 100644 examples/auto_retry/__init__.py delete mode 100644 examples/auto_retry/app.py delete mode 100644 examples/auto_retry/test_app.py rename examples/{fanout => }/fanout.py (50%) delete mode 100644 examples/fanout/__init__.py delete mode 100644 examples/fanout/test_fanout.py create mode 100644 examples/getting_started.py delete mode 100644 examples/getting_started/__init__.py delete mode 100644 examples/getting_started/app.py delete mode 100644 examples/getting_started/test_app.py rename examples/{github_stats/app.py => github_stats.py} (100%) delete mode 100644 examples/github_stats/__init__.py diff --git a/examples/auto_retry.py b/examples/auto_retry.py new file mode 100644 index 00000000..3d6b8bfe --- /dev/null +++ b/examples/auto_retry.py @@ -0,0 +1,23 @@ +import dispatch +import random +import requests + +rng = random.Random(2) + +def third_party_api_call(x: int) -> str: + # Simulate a third-party API call that fails. + print(f"Simulating third-party API call with {x}") + if x < 3: + raise requests.RequestException("Simulated failure") + else: + return "SUCCESS" + + +# Use the `dispatch.function` decorator to declare a stateful function. +@dispatch.function +def application() -> str: + x = rng.randint(0, 5) + return third_party_api_call(x) + + +dispatch.run(application()) diff --git a/examples/auto_retry/__init__.py b/examples/auto_retry/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/auto_retry/app.py b/examples/auto_retry/app.py deleted file mode 100644 index 466e76e0..00000000 --- a/examples/auto_retry/app.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Auto-retry example. - -This example demonstrates how stateful functions automatically retry on failure. - -Make sure to follow the setup instructions at -https://docs.dispatch.run/dispatch/stateful-functions/getting-started/ - -Run with: - -uvicorn app:app - -curl http://localhost:8000/ - - -Observe the logs in the terminal where the server is running. It will show a -handful of attempts before succeeding. - -""" - -import random - -import requests -from fastapi import FastAPI - -from dispatch.fastapi import Dispatch - -# Create the FastAPI app like you normally would. -app = FastAPI() - -# chosen by fair dice roll. guaranteed to be random. -rng = random.Random(2) - -# Create a Dispatch instance and pass the FastAPI app to it. It automatically -# sets up the necessary routes and handlers. -dispatch = Dispatch(app) - - -def third_party_api_call(x): - # Simulate a third-party API call that fails. - print(f"Simulating third-party API call with {x}") - if x < 3: - raise requests.RequestException("Simulated failure") - else: - return "SUCCESS" - - -# Use the `dispatch.function` decorator to declare a stateful function. -@dispatch.function -def some_logic(): - print("Executing some logic") - x = rng.randint(0, 5) - result = third_party_api_call(x) - print("RESULT:", result) - - -# This is a normal FastAPI route that handles regular traffic. -@app.get("/") -def root(): - # Use the `dispatch` method to call the stateful function. This call is - # returns immediately after scheduling the function call, which happens in - # the background. - some_logic.dispatch() - # Sending a response now that the HTTP handler has completed. - return "OK" diff --git a/examples/auto_retry/test_app.py b/examples/auto_retry/test_app.py deleted file mode 100644 index 8ce3f188..00000000 --- a/examples/auto_retry/test_app.py +++ /dev/null @@ -1,51 +0,0 @@ -# This file is not part of the example. It is a test file to ensure the example -# works as expected during the CI process. - - -import os -import unittest -from unittest import mock - -from dispatch import Client -from dispatch.sdk.v1 import status_pb2 as status_pb -from dispatch.test import DispatchServer, DispatchService, EndpointClient -from dispatch.test.fastapi import http_client - - -class TestAutoRetry(unittest.TestCase): - @mock.patch.dict( - os.environ, - { - "DISPATCH_ENDPOINT_URL": "http://function-service", - "DISPATCH_API_KEY": "0000000000000000", - }, - ) - def test_app(self): - from .app import app, dispatch - - # Setup a fake Dispatch server. - app_client = http_client(app) - endpoint_client = EndpointClient(app_client) - dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) - with DispatchServer(dispatch_service) as dispatch_server: - # Use it when dispatching function calls. - dispatch.registry.client = Client(api_url=dispatch_server.url) - - response = app_client.get("/") - self.assertEqual(response.status_code, 200) - - dispatch_service.dispatch_calls() - - # Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6 - # calls, including 5 retries. - for i in range(6): - dispatch_service.dispatch_calls() - - self.assertEqual(len(dispatch_service.roundtrips), 1) - roundtrips = list(dispatch_service.roundtrips.values())[0] - self.assertEqual(len(roundtrips), 6) - - statuses = [response.status for request, response in roundtrips] - self.assertEqual( - statuses, [status_pb.STATUS_TEMPORARY_ERROR] * 5 + [status_pb.STATUS_OK] - ) diff --git a/examples/fanout/fanout.py b/examples/fanout.py similarity index 50% rename from examples/fanout/fanout.py rename to examples/fanout.py index 42bb9ef6..ac04459a 100644 --- a/examples/fanout/fanout.py +++ b/examples/fanout.py @@ -1,29 +1,8 @@ -"""Fan-out example using the SDK gather feature - -This example demonstrates how to use the SDK to fan-out multiple requests. - -Run with: - -uvicorn fanout:app - - -You will observe that the get_repo_info calls are executed in parallel. - -""" - +import dispatch import httpx -from fastapi import FastAPI - -from dispatch import gather -from dispatch.fastapi import Dispatch - -app = FastAPI() - -dispatch = Dispatch(app) - @dispatch.function -async def get_repo(repo_owner: str, repo_name: str): +def get_repo(repo_owner: str, repo_name: str): url = f"https://api.github.com/repos/{repo_owner}/{repo_name}" api_response = httpx.get(url) api_response.raise_for_status() @@ -32,7 +11,7 @@ async def get_repo(repo_owner: str, repo_name: str): @dispatch.function -async def get_stargazers(repo_info): +def get_stargazers(repo_info): url = repo_info["stargazers_url"] response = httpx.get(url) response.raise_for_status() @@ -42,7 +21,7 @@ async def get_stargazers(repo_info): @dispatch.function async def reduce_stargazers(repos): - result = await gather(*[get_stargazers(repo) for repo in repos]) + result = await dispatch.gather(*[get_stargazers(repo) for repo in repos]) reduced_stars = set() for repo in result: for stars in repo: @@ -52,18 +31,12 @@ async def reduce_stargazers(repos): @dispatch.function async def fanout(): - # Using gather, we fan-out the four following requests. - repos = await gather( + # Using gather, we fan-out the following requests: + repos = await dispatch.gather( get_repo("dispatchrun", "coroutine"), get_repo("dispatchrun", "dispatch-py"), get_repo("dispatchrun", "wzprof"), ) + return await reduce_stargazers(repos) - stars = await reduce_stargazers(repos) - print("Total stars:", len(stars)) - - -@app.get("/") -def root(): - fanout.dispatch() - return "OK" +print(dispatch.run(fanout())) diff --git a/examples/fanout/__init__.py b/examples/fanout/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/fanout/test_fanout.py b/examples/fanout/test_fanout.py deleted file mode 100644 index 2f11ceda..00000000 --- a/examples/fanout/test_fanout.py +++ /dev/null @@ -1,19 +0,0 @@ -# This file is not part of the example. It is a test file to ensure the example -# works as expected during the CI process. - - -import os -import unittest -from unittest import mock - - -class TestFanout(unittest.TestCase): - @mock.patch.dict( - os.environ, - { - "DISPATCH_ENDPOINT_URL": "http://function-service", - "DISPATCH_API_KEY": "0000000000000000", - }, - ) - def test_app(self): - pass # Skip this test for now diff --git a/examples/getting_started.py b/examples/getting_started.py new file mode 100644 index 00000000..f2ecacea --- /dev/null +++ b/examples/getting_started.py @@ -0,0 +1,16 @@ +import dispatch +import requests + + +# Use the `dispatch.function` decorator declare a stateful function. +@dispatch.function +def publish(url, payload) -> str: + r = requests.post(url, data=payload) + r.raise_for_status() + return r.text + + +# Use the `dispatch.run` function to run the function with automatic error +# handling and retries. +res = dispatch.run(publish("https://httpstat.us/200", {"hello": "world"})) +print(res) diff --git a/examples/getting_started/__init__.py b/examples/getting_started/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/getting_started/app.py b/examples/getting_started/app.py deleted file mode 100644 index 998e0b21..00000000 --- a/examples/getting_started/app.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Getting started example. - -This is the most basic example to get started with Dispatch Functions. - -Follow along with the tutorial at: -https://docs.dispatch.run/dispatch/getting-started - -The program starts a FastAPI server and initializes the Dispatch SDK that -registers one function. This function makes a dummy but durable HTTP request. -The server exposes one route (`/`). This route's handler asynchronously invokes -the durable function. - -# Setup - -## Get a Dispatch API key - -Sign up for Dispatch and generate a new API key: -https://docs.dispatch.run/dispatch/getting-started/production-deployment#creating-a-verification-key - -## Create a local tunnel - -Use ngrok to create a local tunnel to your server. - -1. Download and install ngrok from https://ngrok.com/download -2. Start a new tunnel with `ngrok http http://localhost:8000` - -Note the forwarding address. - -## Install dependencies - -pip install dispatch-py[fastapi] requests uvicorn[standard] - -# Launch the example - -1. Export the environment variables for the public address and the dispatch API - key. For example: - -export DISPATCH_ENDPOINT_URL=https://ab642fb8661e.ngrok.app -export DISPATCH_API_KEY=s56kfDPal9ErVvVxgFGL6YTcLOvchtg5 -export "DISPATCH_VERIFICATION_KEY=`curl -s \ - -d '{}' \ - -H "Authorization: Bearer $DISPATCH_API_KEY" \ - -H "Content-Type: application/json" \ - https://api.dispatch.run/dispatch.v1.SigningKeyService/CreateSigningKey | \ - jq -r .key.asymmetricKey.publicKey`" - -2. Start the server: - -uvicorn app:app - -3. Request the root handler: - -curl http://localhost:8000/ - -""" - -import requests -from fastapi import FastAPI - -from dispatch.fastapi import Dispatch - -# Create the FastAPI app like you normally would. -app = FastAPI() - -# Create a Dispatch instance and pass the FastAPI app to it. It automatically -# sets up the necessary routes and handlers. -dispatch = Dispatch(app) - - -# Use the `dispatch.function` decorator declare a stateful function. -@dispatch.function -def publish(url, payload): - r = requests.post(url, data=payload) - r.raise_for_status() - - -# This is a normal FastAPI route that handles regular traffic. -@app.get("/") -def root(): - # Use the `dispatch` method to call the stateful function. This call is - # returns immediately after scheduling the function call, which happens in - # the background. - publish.dispatch("https://httpstat.us/200", {"hello": "world"}) - # Sending a response now that the HTTP handler has completed. - return "OK" diff --git a/examples/getting_started/test_app.py b/examples/getting_started/test_app.py deleted file mode 100644 index a3345b92..00000000 --- a/examples/getting_started/test_app.py +++ /dev/null @@ -1,40 +0,0 @@ -# This file is not part of the example. It is a test file to ensure the example -# works as expected during the CI process. - - -import os -import unittest -from unittest import mock - -from dispatch import Client -from dispatch.test import DispatchServer, DispatchService, EndpointClient -from dispatch.test.fastapi import http_client - - -class TestGettingStarted(unittest.TestCase): - @mock.patch.dict( - os.environ, - { - "DISPATCH_ENDPOINT_URL": "http://function-service", - "DISPATCH_API_KEY": "0000000000000000", - }, - ) - def test_app(self): - from .app import app, dispatch - - # Setup a fake Dispatch server. - app_client = http_client(app) - endpoint_client = EndpointClient(app_client) - dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) - with DispatchServer(dispatch_service) as dispatch_server: - # Use it when dispatching function calls. - dispatch.registry.client = Client(api_url=dispatch_server.url) - - response = app_client.get("/") - self.assertEqual(response.status_code, 200) - - dispatch_service.dispatch_calls() - - self.assertEqual(len(dispatch_service.roundtrips), 1) # one call submitted - dispatch_id, roundtrips = list(dispatch_service.roundtrips.items())[0] - self.assertEqual(len(roundtrips), 1) # one roundtrip for this call diff --git a/examples/github_stats/app.py b/examples/github_stats.py similarity index 100% rename from examples/github_stats/app.py rename to examples/github_stats.py diff --git a/examples/github_stats/__init__.py b/examples/github_stats/__init__.py deleted file mode 100644 index e69de29b..00000000 From 280355121429b6842e918783f7f793659c6fc027 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 16:13:42 -0700 Subject: [PATCH 20/34] fix emulation of call results (don't end on temporary errors) Signed-off-by: Achille Roussel --- examples/auto_retry.py | 1 + src/dispatch/http.py | 4 +++- src/dispatch/scheduler.py | 10 +++++----- src/dispatch/status.py | 13 +++++++++++++ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/examples/auto_retry.py b/examples/auto_retry.py index 3d6b8bfe..0fad2f82 100644 --- a/examples/auto_retry.py +++ b/examples/auto_retry.py @@ -1,4 +1,5 @@ import dispatch +import dispatch.integrations.requests import random import requests diff --git a/src/dispatch/http.py b/src/dispatch/http.py index b9937d08..fb12ca49 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -398,7 +398,9 @@ async def function_service_run( call_result = CallResult._from_proto(result) call_future = _calls[req.dispatch_id] if call_result.error is not None: - call_future.set_exception(call_result.error.to_exception()) + call_result.error.status = Status(response.status) + if not call_result.error.status.temporary: + call_future.set_exception(call_result.error.to_exception()) else: call_future.set_result(call_result.output) if result.HasField("output"): diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 112428f0..9140c57a 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -490,15 +490,15 @@ async def _run(self, input: Input) -> Output: state.suspended = {} -async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]): +async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]) -> Optional[Output]: return await make_coroutine(state, coroutine, pending_calls) @coroutine def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]): + coroutine_result: Optional[CoroutineResult] = None + while True: - coroutine_yield = None - coroutine_result: Optional[CoroutineResult] = None try: coroutine_yield = coroutine.run() except TailCall as tc: @@ -535,7 +535,7 @@ def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call] def set_coroutine_result( state: State, coroutine: Coroutine, coroutine_result: CoroutineResult -): +) -> Optional[Output]: if coroutine_result.call is not None: logger.debug("%s reset to %s", coroutine, coroutine_result.call.function) elif coroutine_result.error is not None: @@ -568,7 +568,7 @@ def set_coroutine_result( state.ready.insert(0, parent) del state.suspended[parent.id] logger.debug("parent %s is now ready", parent) - return + return None def set_coroutine_call( diff --git a/src/dispatch/status.py b/src/dispatch/status.py index 7454f80b..fd835ae1 100644 --- a/src/dispatch/status.py +++ b/src/dispatch/status.py @@ -35,6 +35,19 @@ def __repr__(self): def __str__(self): return self.name + # TODO: remove, this is only used for the emulated wait of call results + @property + def temporary(self) -> bool: + return self in { + Status.TIMEOUT, + Status.THROTTLED, + Status.TEMPORARY_ERROR, + Status.INCOMPATIBLE_STATE, + Status.DNS_ERROR, + Status.TCP_ERROR, + Status.TLS_ERROR, + Status.HTTP_ERROR, + } # Maybe we should find a better way to define that enum. It's that way to please # Mypy and provide documentation for the enum values. From 0c186a36f70e12660bafb351909d852b2f708c40 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 16:22:14 -0700 Subject: [PATCH 21/34] handle client connection errors Signed-off-by: Achille Roussel --- src/dispatch/http.py | 7 +++++-- src/dispatch/status.py | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/dispatch/http.py b/src/dispatch/http.py index fb12ca49..099163af 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -16,7 +16,7 @@ overload, ) -from aiohttp import web +from aiohttp import web, ClientConnectionError from http_message_signatures import InvalidSignature from typing_extensions import ParamSpec, TypeAlias @@ -30,7 +30,10 @@ parse_verification_key, verify_request, ) -from dispatch.status import Status +from dispatch.status import Status, register_error_type + +# https://docs.aiohttp.org/en/stable/client_reference.html +register_error_type(ClientConnectionError, Status.TCP_ERROR) logger = logging.getLogger(__name__) diff --git a/src/dispatch/status.py b/src/dispatch/status.py index fd835ae1..82f90d27 100644 --- a/src/dispatch/status.py +++ b/src/dispatch/status.py @@ -1,4 +1,5 @@ import enum +import ssl from typing import Any, Callable, Dict, Type, Union from dispatch.sdk.v1 import status_pb2 as status_pb @@ -129,6 +130,8 @@ def status_for_error(error: BaseException) -> Status: # tend to be caused by invalid use of syscalls, which are # unlikely at higher abstraction levels. return Status.TEMPORARY_ERROR + elif isinstance(error, ssl.SSLError) or isinstance(error, ssl.CertificateError): + return Status.TLS_ERROR return Status.PERMANENT_ERROR From 6799ab3d6730548e3bcb765da3cca08e5c46ecd0 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 16:23:12 -0700 Subject: [PATCH 22/34] fix formatting Signed-off-by: Achille Roussel --- examples/auto_retry.py | 7 +++++-- examples/fanout.py | 5 ++++- examples/getting_started.py | 3 ++- examples/github_stats.py | 4 +++- src/dispatch/function.py | 1 + src/dispatch/http.py | 8 +++++--- src/dispatch/scheduler.py | 4 +++- src/dispatch/status.py | 1 + src/dispatch/test/__init__.py | 2 ++ 9 files changed, 26 insertions(+), 9 deletions(-) diff --git a/examples/auto_retry.py b/examples/auto_retry.py index 0fad2f82..069a3a5b 100644 --- a/examples/auto_retry.py +++ b/examples/auto_retry.py @@ -1,10 +1,13 @@ -import dispatch -import dispatch.integrations.requests import random + import requests +import dispatch +import dispatch.integrations.requests + rng = random.Random(2) + def third_party_api_call(x: int) -> str: # Simulate a third-party API call that fails. print(f"Simulating third-party API call with {x}") diff --git a/examples/fanout.py b/examples/fanout.py index ac04459a..856d496e 100644 --- a/examples/fanout.py +++ b/examples/fanout.py @@ -1,6 +1,8 @@ -import dispatch import httpx +import dispatch + + @dispatch.function def get_repo(repo_owner: str, repo_name: str): url = f"https://api.github.com/repos/{repo_owner}/{repo_name}" @@ -39,4 +41,5 @@ async def fanout(): ) return await reduce_stargazers(repos) + print(dispatch.run(fanout())) diff --git a/examples/getting_started.py b/examples/getting_started.py index f2ecacea..be1b41f4 100644 --- a/examples/getting_started.py +++ b/examples/getting_started.py @@ -1,6 +1,7 @@ -import dispatch import requests +import dispatch + # Use the `dispatch.function` decorator declare a stateful function. @dispatch.function diff --git a/examples/github_stats.py b/examples/github_stats.py index 996743d1..2edc6177 100644 --- a/examples/github_stats.py +++ b/examples/github_stats.py @@ -14,10 +14,12 @@ """ -import dispatch import httpx + +import dispatch from dispatch.error import ThrottleError + def get_gh_api(url): print(f"GET {url}") response = httpx.get(url) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 98442d0e..4b53704d 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -386,6 +386,7 @@ def set_default_registry(reg: Registry): # for results. _calls: Dict[str, asyncio.Future] = {} + class Client: """Client for the Dispatch API.""" diff --git a/src/dispatch/http.py b/src/dispatch/http.py index 099163af..55778bdc 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -16,11 +16,11 @@ overload, ) -from aiohttp import web, ClientConnectionError +from aiohttp import ClientConnectionError, web from http_message_signatures import InvalidSignature from typing_extensions import ParamSpec, TypeAlias -from dispatch.function import Batch, Function, Registry, default_registry, _calls +from dispatch.function import Batch, Function, Registry, _calls, default_registry from dispatch.proto import CallResult, Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -81,7 +81,9 @@ def function(self, func): def batch(self) -> Batch: return self.registry.batch() - async def run(self, url: str, method: str, headers: Mapping[str, str], data: bytes) -> bytes: + async def run( + self, url: str, method: str, headers: Mapping[str, str], data: bytes + ) -> bytes: return await function_service_run( url, method, diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 9140c57a..a5ac9fe0 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -490,7 +490,9 @@ async def _run(self, input: Input) -> Output: state.suspended = {} -async def run_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call]) -> Optional[Output]: +async def run_coroutine( + state: State, coroutine: Coroutine, pending_calls: List[Call] +) -> Optional[Output]: return await make_coroutine(state, coroutine, pending_calls) diff --git a/src/dispatch/status.py b/src/dispatch/status.py index 82f90d27..7563132d 100644 --- a/src/dispatch/status.py +++ b/src/dispatch/status.py @@ -50,6 +50,7 @@ def temporary(self) -> bool: Status.HTTP_ERROR, } + # Maybe we should find a better way to define that enum. It's that way to please # Mypy and provide documentation for the enum values. diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 1b93ff47..dd69a9e4 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -60,6 +60,7 @@ _dispatch_ids = (str(i) for i in range(2**32 - 1)) + class Client(BaseClient): def session(self) -> aiohttp.ClientSession: # Use an individual sessionn in the test client instead of the default @@ -76,6 +77,7 @@ def __init__(self, app: web.Application): def url(self): return f"http://{self.host}:{self.port}" + class Service(web.Application): tasks: Dict[str, asyncio.Task] _session: Optional[aiohttp.ClientSession] = None From a406d2a1a47a5e83d415a21209dea3b1f1d1a6a7 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 16:25:36 -0700 Subject: [PATCH 23/34] remove type annotations from examples so they are compatible with Python 3.8 Signed-off-by: Achille Roussel --- examples/auto_retry.py | 4 ++-- examples/getting_started.py | 2 +- examples/github_stats.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/auto_retry.py b/examples/auto_retry.py index 069a3a5b..5eff8bd7 100644 --- a/examples/auto_retry.py +++ b/examples/auto_retry.py @@ -8,7 +8,7 @@ rng = random.Random(2) -def third_party_api_call(x: int) -> str: +def third_party_api_call(x): # Simulate a third-party API call that fails. print(f"Simulating third-party API call with {x}") if x < 3: @@ -19,7 +19,7 @@ def third_party_api_call(x: int) -> str: # Use the `dispatch.function` decorator to declare a stateful function. @dispatch.function -def application() -> str: +def application(): x = rng.randint(0, 5) return third_party_api_call(x) diff --git a/examples/getting_started.py b/examples/getting_started.py index be1b41f4..38ad9cfb 100644 --- a/examples/getting_started.py +++ b/examples/getting_started.py @@ -5,7 +5,7 @@ # Use the `dispatch.function` decorator declare a stateful function. @dispatch.function -def publish(url, payload) -> str: +def publish(url, payload): r = requests.post(url, data=payload) r.raise_for_status() return r.text diff --git a/examples/github_stats.py b/examples/github_stats.py index 2edc6177..c1dc6db5 100644 --- a/examples/github_stats.py +++ b/examples/github_stats.py @@ -31,21 +31,21 @@ def get_gh_api(url): @dispatch.function -async def get_repo_info(repo_owner: str, repo_name: str) -> dict: +async def get_repo_info(repo_owner, repo_name): url = f"https://api.github.com/repos/{repo_owner}/{repo_name}" repo_info = get_gh_api(url) return repo_info @dispatch.function -async def get_contributors(repo_info: dict) -> list[dict]: +async def get_contributors(repo_info): url = repo_info["contributors_url"] contributors = get_gh_api(url) return contributors @dispatch.function -async def main() -> list[dict]: +async def main(): repo_info = await get_repo_info("dispatchrun", "coroutine") print( f"""Repository: {repo_info['full_name']} From d5f58483205d4e9dcc86862c854a482584eb1852 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 17:48:19 -0700 Subject: [PATCH 24/34] use context varaibles instead of thread locals Signed-off-by: Achille Roussel --- src/dispatch/scheduler.py | 19 ++++--------- src/dispatch/test/__init__.py | 53 +++++++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index a5ac9fe0..b0cbb093 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -1,8 +1,8 @@ import asyncio +import contextvars import logging import pickle import sys -import threading from dataclasses import dataclass, field from types import coroutine from typing import ( @@ -32,19 +32,10 @@ CoroutineID: TypeAlias = int CorrelationID: TypeAlias = int - -class ThreadLocal(threading.local): - in_function_call: bool - - def __init__(self): - self.in_function_call = False - - -thread_local = ThreadLocal() - +_in_function_call = contextvars.ContextVar("dispatch.scheduler.in_function_call", default=False) def in_function_call() -> bool: - return thread_local.in_function_call + return bool(_in_function_call.get()) @dataclass @@ -343,7 +334,7 @@ def __init__( async def run(self, input: Input) -> Output: try: - thread_local.in_function_call = True + token = _in_function_call.set(True) return await self._run(input) except Exception as e: logger.exception( @@ -351,7 +342,7 @@ async def run(self, input: Input) -> Output: ) return Output.error(Error.from_exception(e)) finally: - thread_local.in_function_call = False + _in_function_call.reset(token) def _init_state(self, input: Input) -> State: logger.debug("starting main coroutine") diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index dd69a9e4..66502349 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -4,7 +4,7 @@ import unittest from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar +from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar import aiohttp from aiohttp import web @@ -193,6 +193,7 @@ def make_request(call: Call) -> RunRequest: else: error = Error(type="status", message=str(res.status)) return CallResult( + correlation_id=call.correlation_id, dispatch_id=dispatch_id, error=error, ) @@ -203,6 +204,7 @@ def make_request(call: Call) -> RunRequest: continue result = res.exit.result return CallResult( + correlation_id=call.correlation_id, dispatch_id=dispatch_id, output=result.output if result.HasField("output") else None, error=result.error if result.HasField("error") else None, @@ -317,6 +319,7 @@ def test(self): endpoint=DISPATCH_ENDPOINT_URL, client=Client(api_key=DISPATCH_API_KEY, api_url=DISPATCH_API_URL), ) +set_default_registry(_registry) @_registry.function @@ -354,7 +357,33 @@ async def broken_nested(name: str) -> str: return await broken() -set_default_registry(_registry) +@_registry.function +async def distributed_merge_sort(values: List[int]) -> List[int]: + if len(values) <= 1: + return values + i = len(values) // 2 + + (l, r) = await dispatch.gather( + distributed_merge_sort(values[:i]), + distributed_merge_sort(values[i:]), + ) + + return merge(l, r) + + +def merge(l: List[int], r: List[int]) -> List[int]: + result = [] + i = j = 0 + while i < len(l) and j < len(r): + if l[i] < r[j]: + result.append(l[i]) + i += 1 + else: + result.append(r[j]) + j += 1 + result.extend(l[i:]) + result.extend(r[j:]) + return result class TestCase(unittest.TestCase): @@ -473,6 +502,26 @@ async def test_call_nested_function_with_error(self): with self.assertRaises(ValueError) as e: await broken_nested("hello") + @aiotest + async def test_distributed_merge_sort_no_values(self): + values: List[int] = [] + self.assertEqual(await distributed_merge_sort(values), sorted(values)) + + @aiotest + async def test_distributed_merge_sort_one_value(self): + values: List[int] = [1] + self.assertEqual(await distributed_merge_sort(values), sorted(values)) + + @aiotest + async def test_distributed_merge_sort_two_values(self): + values: List[int] = [1, 5] + self.assertEqual(await distributed_merge_sort(values), sorted(values)) + + @aiotest + async def test_distributed_merge_sort_many_values(self): + values: List[int] = [1, 5, 3, 2, 4, 6, 7, 8, 9, 0] + self.assertEqual(await distributed_merge_sort(values), sorted(values)) + class ClientRequestContentLengthMissing(aiohttp.ClientRequest): def update_headers(self, skip_auto_headers): From 091891874ae3ecd1e5483c123e2eddf3c91a2f68 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 17:49:11 -0700 Subject: [PATCH 25/34] update README Signed-off-by: Achille Roussel --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 56adefb1..9a726d8b 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ import dispatch def greet(msg: str): print(f"Hello, ${msg}!") -dispatch.run(lambda: greet.dispatch('World')) +dispatch.run(greet('World')) ``` Obviously, this is just an example, a real application would perform much more From d0162218c932a68418785a68270eb855dc4f48e6 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 21:44:31 -0700 Subject: [PATCH 26/34] fix running examples as tests Signed-off-by: Achille Roussel --- examples/auto_retry.py | 6 ++++-- examples/fanout.py | 3 ++- examples/getting_started.py | 12 +++++++----- examples/github_stats.py | 24 ++++-------------------- examples/test_examples.py | 29 +++++++++++++++++++++++++++++ src/dispatch/scheduler.py | 11 +++++++++-- src/dispatch/test/__init__.py | 30 ++++++++++++++++++++++++++++-- 7 files changed, 83 insertions(+), 32 deletions(-) create mode 100644 examples/test_examples.py diff --git a/examples/auto_retry.py b/examples/auto_retry.py index 5eff8bd7..fe932c92 100644 --- a/examples/auto_retry.py +++ b/examples/auto_retry.py @@ -12,6 +12,7 @@ def third_party_api_call(x): # Simulate a third-party API call that fails. print(f"Simulating third-party API call with {x}") if x < 3: + print("RAISE EXCEPTION") raise requests.RequestException("Simulated failure") else: return "SUCCESS" @@ -19,9 +20,10 @@ def third_party_api_call(x): # Use the `dispatch.function` decorator to declare a stateful function. @dispatch.function -def application(): +def auto_retry(): x = rng.randint(0, 5) return third_party_api_call(x) -dispatch.run(application()) +if __name__ == "__main__": + print(dispatch.run(auto_retry())) diff --git a/examples/fanout.py b/examples/fanout.py index 856d496e..883b31ad 100644 --- a/examples/fanout.py +++ b/examples/fanout.py @@ -42,4 +42,5 @@ async def fanout(): return await reduce_stargazers(repos) -print(dispatch.run(fanout())) +if __name__ == "__main__": + print(dispatch.run(fanout())) diff --git a/examples/getting_started.py b/examples/getting_started.py index 38ad9cfb..17ff7f8e 100644 --- a/examples/getting_started.py +++ b/examples/getting_started.py @@ -3,7 +3,6 @@ import dispatch -# Use the `dispatch.function` decorator declare a stateful function. @dispatch.function def publish(url, payload): r = requests.post(url, data=payload) @@ -11,7 +10,10 @@ def publish(url, payload): return r.text -# Use the `dispatch.run` function to run the function with automatic error -# handling and retries. -res = dispatch.run(publish("https://httpstat.us/200", {"hello": "world"})) -print(res) +@dispatch.function +async def getting_started(): + return await publish("https://httpstat.us/200", {"hello": "world"}) + + +if __name__ == "__main__": + print(dispatch.run(getting_started())) diff --git a/examples/github_stats.py b/examples/github_stats.py index c1dc6db5..0636882d 100644 --- a/examples/github_stats.py +++ b/examples/github_stats.py @@ -1,19 +1,3 @@ -"""Github repository stats example. - -This example demonstrates how to use async functions orchestrated by Dispatch. - -Make sure to follow the setup instructions at -https://docs.dispatch.run/dispatch/stateful-functions/getting-started/ - -Run with: - -uvicorn app:app - - -Logs will show a pipeline of functions being called and their results. - -""" - import httpx import dispatch @@ -31,21 +15,21 @@ def get_gh_api(url): @dispatch.function -async def get_repo_info(repo_owner, repo_name): +def get_repo_info(repo_owner, repo_name): url = f"https://api.github.com/repos/{repo_owner}/{repo_name}" repo_info = get_gh_api(url) return repo_info @dispatch.function -async def get_contributors(repo_info): +def get_contributors(repo_info): url = repo_info["contributors_url"] contributors = get_gh_api(url) return contributors @dispatch.function -async def main(): +async def github_stats(): repo_info = await get_repo_info("dispatchrun", "coroutine") print( f"""Repository: {repo_info['full_name']} @@ -57,5 +41,5 @@ async def main(): if __name__ == "__main__": - contributors = dispatch.run(main()) + contributors = dispatch.run(github_stats()) print(f"Contributors: {len(contributors)}") diff --git a/examples/test_examples.py b/examples/test_examples.py new file mode 100644 index 00000000..4d63a517 --- /dev/null +++ b/examples/test_examples.py @@ -0,0 +1,29 @@ +import dispatch.test + +from .auto_retry import auto_retry +from .fanout import fanout +from .getting_started import getting_started +from .github_stats import github_stats + + +@dispatch.test.function +async def test_auto_retry(): + assert await auto_retry() == "SUCCESS" + + +@dispatch.test.function +async def test_fanout(): + contributors = await fanout() + assert len(contributors) >= 15 + assert "achille-roussel" in contributors + + +@dispatch.test.function +async def test_getting_started(): + assert await getting_started() == "200 OK" + + +@dispatch.test.function +async def test_github_stats(): + contributors = await github_stats() + assert len(contributors) >= 6 diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index b0cbb093..c6dfdc5c 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -32,7 +32,10 @@ CoroutineID: TypeAlias = int CorrelationID: TypeAlias = int -_in_function_call = contextvars.ContextVar("dispatch.scheduler.in_function_call", default=False) +_in_function_call = contextvars.ContextVar( + "dispatch.scheduler.in_function_call", default=False +) + def in_function_call() -> bool: return bool(_in_function_call.get()) @@ -523,7 +526,11 @@ def make_coroutine(state: State, coroutine: Coroutine, pending_calls: List[Call] if isinstance(coroutine_yield, RaceDirective): return set_coroutine_race(state, coroutine, coroutine_yield.awaitables) - yield coroutine_yield + try: + yield coroutine_yield + except Exception as e: + coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e) + return set_coroutine_result(state, coroutine, coroutine_result) def set_coroutine_result( diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 66502349..7250c96d 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -28,7 +28,17 @@ from dispatch.sdk.v1.error_pb2 import Error from dispatch.sdk.v1.function_pb2 import RunRequest, RunResponse from dispatch.sdk.v1.poll_pb2 import PollResult -from dispatch.sdk.v1.status_pb2 import STATUS_OK +from dispatch.sdk.v1.status_pb2 import ( + STATUS_DNS_ERROR, + STATUS_HTTP_ERROR, + STATUS_INCOMPATIBLE_STATE, + STATUS_OK, + STATUS_TCP_ERROR, + STATUS_TEMPORARY_ERROR, + STATUS_THROTTLED, + STATUS_TIMEOUT, + STATUS_TLS_ERROR, +) from .client import EndpointClient from .server import DispatchServer @@ -183,7 +193,18 @@ def make_request(call: Call) -> RunRequest: res = await self.run(call.endpoint, req) if res.status != STATUS_OK: - # TODO: emulate retries etc... + if res.status in ( + STATUS_TIMEOUT, + STATUS_THROTTLED, + STATUS_TEMPORARY_ERROR, + STATUS_INCOMPATIBLE_STATE, + STATUS_DNS_ERROR, + STATUS_TCP_ERROR, + STATUS_TLS_ERROR, + STATUS_HTTP_ERROR, + ): + continue # emulate retries, without backoff for now + if ( res.HasField("exit") and res.exit.HasField("result") @@ -263,14 +284,19 @@ async def main(coro: Coroutine[Any, Any, None]) -> None: api = Service() app = Dispatch(reg) try: + print("Starting bakend") async with Server(api) as backend: + print("Starting server") async with Server(app) as server: # Here we break through the abstraction layers a bit, it's not # ideal but it works for now. reg.client.api_url.value = backend.url reg.endpoint = server.url + print("BACKEND:", backend.url) + print("SERVER:", server.url) await coro finally: + print("DONE!") await api.close() # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially. From cf363d063c4b82c2960a8920b6cf5142ef2f2bf2 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Mon, 17 Jun 2024 22:06:19 -0700 Subject: [PATCH 27/34] switch pytest.mark.asyncio position because maybe that's why the tests are skipped on 3.8/3.9 Signed-off-by: Achille Roussel --- tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index 70e754d2..7c4d4224 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -34,8 +34,8 @@ def test_can_be_constructed_on_https(): Client(api_url="https://example.com", api_key="foo") -@mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "0000000000000000"}) @pytest.mark.asyncio +@mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "0000000000000000"}) async def test_api_key_from_env(): async with server() as api: client = Client(api_url=api.url) From debe5d81bafd31e060ce54a81f81b5eb2b57a5a4 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jun 2024 10:28:50 -0700 Subject: [PATCH 28/34] fix tests for Python 3.9 Signed-off-by: Achille Roussel --- tests/test_client.py | 31 +++++++++++++++++++++---------- tests/test_fastapi.py | 2 +- tests/test_http.py | 2 +- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 7c4d4224..e6fa5abf 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -34,18 +34,29 @@ def test_can_be_constructed_on_https(): Client(api_url="https://example.com", api_key="foo") +# On Python 3.8/3.9, pytest.mark.asyncio doesn't work with mock.patch.dict, +# so we have to use the old-fashioned way of setting the environment variable +# and then cleaning it up manually. +# +# @mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "0000000000000000"}) @pytest.mark.asyncio -@mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "0000000000000000"}) async def test_api_key_from_env(): - async with server() as api: - client = Client(api_url=api.url) - - with pytest.raises( - PermissionError, - match=r"Dispatch received an invalid authentication token \(check DISPATCH_API_KEY is correct\)", - ) as mc: - await client.dispatch([Call(function="my-function", input=42)]) - + prev_api_key = os.environ.get("DISPATCH_API_KEY") + try: + os.environ["DISPATCH_API_KEY"] = "0000000000000000" + async with server() as api: + client = Client(api_url=api.url) + + with pytest.raises( + PermissionError, + match=r"Dispatch received an invalid authentication token \(check DISPATCH_API_KEY is correct\)", + ) as mc: + await client.dispatch([Call(function="my-function", input=42)]) + finally: + if prev_api_key is None: + del os.environ["DISPATCH_API_KEY"] + else: + os.environ["DISPATCH_API_KEY"] = prev_api_key @pytest.mark.asyncio async def test_api_key_from_arg(): diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 9c067ef4..a317a9be 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -57,7 +57,7 @@ def dispatch_test_init(self, reg: Registry) -> str: self.sockets = [sock] self.uvicorn = uvicorn.Server(config) self.runner = Runner() - if sys.version_info >= (3, 9): + if sys.version_info >= (3, 10): self.event = asyncio.Event() else: self.event = asyncio.Event(loop=self.runner.get_loop()) diff --git a/tests/test_http.py b/tests/test_http.py index 69ad3654..cba4742b 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -46,7 +46,7 @@ def dispatch_test_init(self, reg: Registry) -> str: self.aiohttp = Server(host, port, Dispatch(reg)) self.aioloop.run(self.aiohttp.start()) - if sys.version_info >= (3, 9): + if sys.version_info >= (3, 10): self.aiowait = asyncio.Event() else: self.aiowait = asyncio.Event(loop=self.aioloop.get_loop()) From 848715988ca79a8466ba031d5cc1cfaa79d28aa8 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jun 2024 10:36:05 -0700 Subject: [PATCH 29/34] fix formatting Signed-off-by: Achille Roussel --- tests/test_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index e6fa5abf..871de264 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -48,8 +48,8 @@ async def test_api_key_from_env(): client = Client(api_url=api.url) with pytest.raises( - PermissionError, - match=r"Dispatch received an invalid authentication token \(check DISPATCH_API_KEY is correct\)", + PermissionError, + match=r"Dispatch received an invalid authentication token \(check DISPATCH_API_KEY is correct\)", ) as mc: await client.dispatch([Call(function="my-function", input=42)]) finally: @@ -58,6 +58,7 @@ async def test_api_key_from_env(): else: os.environ["DISPATCH_API_KEY"] = prev_api_key + @pytest.mark.asyncio async def test_api_key_from_arg(): async with server() as api: From 672b1a122def05c53fcb4038572b5cf5b55448eb Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jun 2024 11:56:33 -0700 Subject: [PATCH 30/34] re-enable the batch API Signed-off-by: Achille Roussel --- src/dispatch/function.py | 17 +- tests/test_fastapi.py | 360 --------------------------------------- 2 files changed, 8 insertions(+), 369 deletions(-) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 4b53704d..dc8a54aa 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -330,11 +330,10 @@ def _register(self, name: str, wrapped_func: PrimitiveFunction): raise ValueError(f"function already registered with name '{name}'") self.functions[name] = wrapped_func - def batch(self): # -> Batch: + def batch(self) -> Batch: """Returns a Batch instance that can be used to build a set of calls to dispatch.""" - # return self.client.batch() - raise NotImplemented + return self.client.batch() _registries: Dict[str, Registry] = {} @@ -565,7 +564,11 @@ def add_call(self, call: Call) -> Batch: self.calls.append(call) return self - def dispatch(self) -> List[DispatchID]: + def clear(self): + """Reset the batch.""" + self.calls = [] + + async def dispatch(self) -> List[DispatchID]: """Dispatch dispatches the calls asynchronously. The batch is reset when the calls are dispatched successfully. @@ -576,10 +579,6 @@ def dispatch(self) -> List[DispatchID]: """ if not self.calls: return [] - dispatch_ids = asyncio.run(self.client.dispatch(self.calls)) + dispatch_ids = await self.client.dispatch(self.calls) self.clear() return dispatch_ids - - def clear(self): - """Reset the batch.""" - self.calls = [] diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index a317a9be..be4964aa 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -100,363 +100,3 @@ def create_endpoint_client( def response_output(resp: function_pb.RunResponse) -> Any: return any_unpickle(resp.exit.result.output) - -# class TestCoroutine(unittest.TestCase): -# def setUp(self): -# clear_functions() - -# self.app = fastapi.FastAPI() - -# @self.app.get("/") -# def root(): -# return "OK" - -# self.dispatch = create_dispatch_instance( -# self.app, endpoint="https://127.0.0.1:9999" -# ) -# self.http_client = TestClient(self.app) -# self.client = create_endpoint_client(self.app) - -# def tearDown(self): -# self.dispatch.registry.close() - -# def execute( -# self, func: Function, input=None, state=None, calls=None -# ) -> function_pb.RunResponse: -# """Test helper to invoke coroutines on the local server.""" -# req = function_pb.RunRequest(function=func.name) - -# if input is not None: -# input_bytes = pickle.dumps(input) -# input_any = google.protobuf.any_pb2.Any() -# input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes)) -# req.input.CopyFrom(input_any) -# if state is not None: -# req.poll_result.coroutine_state = state -# if calls is not None: -# for c in calls: -# req.poll_result.results.append(c) - -# resp = self.client.run(req) -# self.assertIsInstance(resp, function_pb.RunResponse) -# return resp - -# def call(self, func: Function, *args, **kwargs) -> function_pb.RunResponse: -# return self.execute(func, input=Arguments(args, kwargs)) - -# def proto_call(self, call: call_pb.Call) -> call_pb.CallResult: -# req = function_pb.RunRequest( -# function=call.function, -# input=call.input, -# ) -# resp = self.client.run(req) -# self.assertIsInstance(resp, function_pb.RunResponse) - -# # Assert the response is terminal. Good enough until the test client can -# # orchestrate coroutines. -# self.assertTrue(len(resp.poll.coroutine_state) == 0) - -# resp.exit.result.correlation_id = call.correlation_id -# return resp.exit.result - -# def test_no_input(self): -# @self.dispatch.primitive_function -# async def my_function(input: Input) -> Output: -# return Output.value("Hello World!") - -# resp = self.execute(my_function) - -# out = response_output(resp) -# self.assertEqual(out, "Hello World!") - -# def test_missing_coroutine(self): -# req = function_pb.RunRequest( -# function="does-not-exist", -# ) - -# with self.assertRaises(httpx.HTTPStatusError) as cm: -# self.client.run(req) -# self.assertEqual(cm.exception.response.status_code, 404) - -# def test_string_input(self): -# @self.dispatch.primitive_function -# async def my_function(input: Input) -> Output: -# return Output.value(f"You sent '{input.input}'") - -# resp = self.execute(my_function, input="cool stuff") -# out = response_output(resp) -# self.assertEqual(out, "You sent 'cool stuff'") - -# def test_error_on_access_state_in_first_call(self): -# @self.dispatch.primitive_function -# async def my_function(input: Input) -> Output: -# try: -# print(input.coroutine_state) -# except ValueError: -# return Output.error( -# Error.from_exception( -# ValueError("This input is for a first function call") -# ) -# ) -# return Output.value("not reached") - -# resp = self.execute(my_function, input="cool stuff") -# self.assertEqual("ValueError", resp.exit.result.error.type) -# self.assertEqual( -# "This input is for a first function call", resp.exit.result.error.message -# ) - -# def test_error_on_access_input_in_second_call(self): -# @self.dispatch.primitive_function -# async def my_function(input: Input) -> Output: -# if input.is_first_call: -# return Output.poll(coroutine_state=b"42") -# try: -# print(input.input) -# except ValueError: -# return Output.error( -# Error.from_exception( -# ValueError("This input is for a resumed coroutine") -# ) -# ) -# return Output.value("not reached") - -# resp = self.execute(my_function, input="cool stuff") -# self.assertEqual(b"42", resp.poll.coroutine_state) - -# resp = self.execute(my_function, state=resp.poll.coroutine_state) -# self.assertEqual("ValueError", resp.exit.result.error.type) -# self.assertEqual( -# "This input is for a resumed coroutine", resp.exit.result.error.message -# ) - -# def test_duplicate_coro(self): -# @self.dispatch.primitive_function -# async def my_function(input: Input) -> Output: -# return Output.value("Do one thing") - -# with self.assertRaises(ValueError): - -# @self.dispatch.primitive_function -# async def my_function(input: Input) -> Output: -# return Output.value("Do something else") - -# def test_two_simple_coroutines(self): -# @self.dispatch.primitive_function -# async def echoroutine(input: Input) -> Output: -# return Output.value(f"Echo: '{input.input}'") - -# @self.dispatch.primitive_function -# async def len_coroutine(input: Input) -> Output: -# return Output.value(f"Length: {len(input.input)}") - -# data = "cool stuff" -# resp = self.execute(echoroutine, input=data) -# out = response_output(resp) -# self.assertEqual(out, "Echo: 'cool stuff'") - -# resp = self.execute(len_coroutine, input=data) -# out = response_output(resp) -# self.assertEqual(out, "Length: 10") - -# def test_coroutine_with_state(self): -# @self.dispatch.primitive_function -# async def coroutine3(input: Input) -> Output: -# if input.is_first_call: -# counter = input.input -# else: -# (counter,) = struct.unpack("@i", input.coroutine_state) -# counter -= 1 -# if counter <= 0: -# return Output.value("done") -# coroutine_state = struct.pack("@i", counter) -# return Output.poll(coroutine_state=coroutine_state) - -# # first call -# resp = self.execute(coroutine3, input=4) -# state = resp.poll.coroutine_state -# self.assertTrue(len(state) > 0) - -# # resume, state = 3 -# resp = self.execute(coroutine3, state=state) -# state = resp.poll.coroutine_state -# self.assertTrue(len(state) > 0) - -# # resume, state = 2 -# resp = self.execute(coroutine3, state=state) -# state = resp.poll.coroutine_state -# self.assertTrue(len(state) > 0) - -# # resume, state = 1 -# resp = self.execute(coroutine3, state=state) -# state = resp.poll.coroutine_state -# self.assertTrue(len(state) == 0) -# out = response_output(resp) -# self.assertEqual(out, "done") - -# def test_coroutine_poll(self): -# @self.dispatch.primitive_function -# async def coro_compute_len(input: Input) -> Output: -# return Output.value(len(input.input)) - -# @self.dispatch.primitive_function -# async def coroutine_main(input: Input) -> Output: -# if input.is_first_call: -# text: str = input.input -# return Output.poll( -# coroutine_state=text.encode(), -# calls=[coro_compute_len._build_primitive_call(text)], -# ) -# text = input.coroutine_state.decode() -# length = input.call_results[0].output -# return Output.value(f"length={length} text='{text}'") - -# resp = self.execute(coroutine_main, input="cool stuff") - -# # main saved some state -# state = resp.poll.coroutine_state -# self.assertTrue(len(state) > 0) -# # main asks for 1 call to compute_len -# self.assertEqual(len(resp.poll.calls), 1) -# call = resp.poll.calls[0] -# self.assertEqual(call.function, coro_compute_len.name) -# self.assertEqual(any_unpickle(call.input), "cool stuff") - -# # make the requested compute_len -# resp2 = self.proto_call(call) -# # check the result is the terminal expected response -# len_resp = any_unpickle(resp2.output) -# self.assertEqual(10, len_resp) - -# # resume main with the result -# resp = self.execute(coroutine_main, state=state, calls=[resp2]) -# # validate the final result -# self.assertTrue(len(resp.poll.coroutine_state) == 0) -# out = response_output(resp) -# self.assertEqual("length=10 text='cool stuff'", out) - -# def test_coroutine_poll_error(self): -# @self.dispatch.primitive_function -# async def coro_compute_len(input: Input) -> Output: -# return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead")) - -# @self.dispatch.primitive_function -# async def coroutine_main(input: Input) -> Output: -# if input.is_first_call: -# text: str = input.input -# return Output.poll( -# coroutine_state=text.encode(), -# calls=[coro_compute_len._build_primitive_call(text)], -# ) -# error = input.call_results[0].error -# if error is not None: -# return Output.value(f"msg={error.message} type='{error.type}'") -# else: -# raise RuntimeError(f"unexpected call results: {input.call_results}") - -# resp = self.execute(coroutine_main, input="cool stuff") - -# # main saved some state -# state = resp.poll.coroutine_state -# self.assertTrue(len(state) > 0) -# # main asks for 1 call to compute_len -# self.assertEqual(len(resp.poll.calls), 1) -# call = resp.poll.calls[0] -# self.assertEqual(call.function, coro_compute_len.name) -# self.assertEqual(any_unpickle(call.input), "cool stuff") - -# # make the requested compute_len -# resp2 = self.proto_call(call) - -# # resume main with the result -# resp = self.execute(coroutine_main, state=state, calls=[resp2]) -# # validate the final result -# self.assertTrue(len(resp.poll.coroutine_state) == 0) -# out = response_output(resp) -# self.assertEqual(out, "msg=Dead type='type'") - -# def test_coroutine_error(self): -# @self.dispatch.primitive_function -# async def mycoro(input: Input) -> Output: -# return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead")) - -# resp = self.execute(mycoro) -# self.assertEqual("sometype", resp.exit.result.error.type) -# self.assertEqual("dead", resp.exit.result.error.message) - -# def test_coroutine_expected_exception(self): -# @self.dispatch.primitive_function -# async def mycoro(input: Input) -> Output: -# try: -# 1 / 0 -# except ZeroDivisionError as e: -# return Output.error(Error.from_exception(e)) -# self.fail("should not reach here") - -# resp = self.execute(mycoro) -# self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) -# self.assertEqual("division by zero", resp.exit.result.error.message) -# self.assertEqual(Status.PERMANENT_ERROR, resp.status) - -# def test_coroutine_unexpected_exception(self): -# @self.dispatch.function -# def mycoro(): -# 1 / 0 -# self.fail("should not reach here") - -# resp = self.call(mycoro) -# self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) -# self.assertEqual("division by zero", resp.exit.result.error.message) -# self.assertEqual(Status.PERMANENT_ERROR, resp.status) - -# def test_specific_status(self): -# @self.dispatch.primitive_function -# async def mycoro(input: Input) -> Output: -# return Output.error(Error(Status.THROTTLED, "foo", "bar")) - -# resp = self.execute(mycoro) -# self.assertEqual("foo", resp.exit.result.error.type) -# self.assertEqual("bar", resp.exit.result.error.message) -# self.assertEqual(Status.THROTTLED, resp.status) - -# def test_tailcall(self): -# @self.dispatch.function -# def other_coroutine(value: Any) -> str: -# return f"Hello {value}" - -# @self.dispatch.primitive_function -# async def mycoro(input: Input) -> Output: -# return Output.tail_call(other_coroutine._build_primitive_call(42)) - -# resp = self.call(mycoro) -# self.assertEqual(other_coroutine.name, resp.exit.tail_call.function) -# self.assertEqual(42, any_unpickle(resp.exit.tail_call.input)) - -# def test_library_error_categorization(self): -# @self.dispatch.function -# def get(path: str) -> httpx.Response: -# http_response = self.http_client.get(path) -# http_response.raise_for_status() -# return http_response - -# resp = self.call(get, "/") -# self.assertEqual(Status.OK, Status(resp.status)) -# http_response = any_unpickle(resp.exit.result.output) -# self.assertEqual("application/json", http_response.headers["content-type"]) -# self.assertEqual('"OK"', http_response.text) - -# resp = self.call(get, "/missing") -# self.assertEqual(Status.NOT_FOUND, Status(resp.status)) - -# def test_library_output_categorization(self): -# @self.dispatch.function -# def get(path: str) -> httpx.Response: -# http_response = self.http_client.get(path) -# http_response.status_code = 429 -# return http_response - -# resp = self.call(get, "/") -# self.assertEqual(Status.THROTTLED, Status(resp.status)) -# http_response = any_unpickle(resp.exit.result.output) -# self.assertEqual("application/json", http_response.headers["content-type"]) -# self.assertEqual('"OK"', http_response.text) From fea85a874f749eadd02b181dbe15e536154c08c1 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jun 2024 12:03:23 -0700 Subject: [PATCH 31/34] remove unused files Signed-off-by: Achille Roussel --- pyproject.toml | 12 +- src/buf/validate/expression_pb2_grpc.py | 3 - src/buf/validate/priv/private_pb2_grpc.py | 3 - src/buf/validate/validate_pb2_grpc.py | 3 - src/dispatch/function.py | 2 +- .../sdk/python/v1/pickled_pb2_grpc.py | 3 - src/dispatch/sdk/v1/call_pb2_grpc.py | 3 - src/dispatch/sdk/v1/dispatch_pb2_grpc.py | 94 ----- src/dispatch/sdk/v1/error_pb2_grpc.py | 3 - src/dispatch/sdk/v1/exit_pb2_grpc.py | 3 - src/dispatch/sdk/v1/function_pb2_grpc.py | 88 ----- src/dispatch/sdk/v1/poll_pb2_grpc.py | 3 - src/dispatch/sdk/v1/status_pb2_grpc.py | 3 - src/dispatch/{test/__init__.py => test.py} | 7 - src/dispatch/test/client.py | 155 -------- src/dispatch/test/fastapi.py | 10 - src/dispatch/test/flask.py | 46 --- src/dispatch/test/http.py | 34 -- src/dispatch/test/httpx.py | 39 -- src/dispatch/test/server.py | 61 --- src/dispatch/test/service.py | 362 ------------------ tests/test_fastapi.py | 26 -- 22 files changed, 4 insertions(+), 959 deletions(-) delete mode 100644 src/buf/validate/expression_pb2_grpc.py delete mode 100644 src/buf/validate/priv/private_pb2_grpc.py delete mode 100644 src/buf/validate/validate_pb2_grpc.py delete mode 100644 src/dispatch/sdk/python/v1/pickled_pb2_grpc.py delete mode 100644 src/dispatch/sdk/v1/call_pb2_grpc.py delete mode 100644 src/dispatch/sdk/v1/dispatch_pb2_grpc.py delete mode 100644 src/dispatch/sdk/v1/error_pb2_grpc.py delete mode 100644 src/dispatch/sdk/v1/exit_pb2_grpc.py delete mode 100644 src/dispatch/sdk/v1/function_pb2_grpc.py delete mode 100644 src/dispatch/sdk/v1/poll_pb2_grpc.py delete mode 100644 src/dispatch/sdk/v1/status_pb2_grpc.py rename src/dispatch/{test/__init__.py => test.py} (99%) delete mode 100644 src/dispatch/test/client.py delete mode 100644 src/dispatch/test/fastapi.py delete mode 100644 src/dispatch/test/flask.py delete mode 100644 src/dispatch/test/http.py delete mode 100644 src/dispatch/test/httpx.py delete mode 100644 src/dispatch/test/server.py delete mode 100644 src/dispatch/test/service.py diff --git a/pyproject.toml b/pyproject.toml index 2aff75b7..71acc6fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,18 +10,17 @@ dynamic = ["version"] requires-python = ">= 3.8" dependencies = [ "aiohttp >= 3.9.4", - "grpcio >= 1.60.0", "protobuf >= 4.24.0", "types-protobuf >= 4.24.0.20240129", - "grpc-stubs >= 1.53.0.5", "http-message-signatures >= 0.5.0", "tblib >= 3.0.0", "typing_extensions >= 4.10" ] [project.optional-dependencies] -fastapi = ["fastapi", "httpx"] +fastapi = ["fastapi"] flask = ["flask"] +httpx = ["httpx"] lambda = ["awslambdaric"] dev = [ @@ -60,17 +59,12 @@ profile = "black" src_paths = ["src"] [tool.coverage.run] -omit = ["*_pb2_grpc.py", "*_pb2.py", "tests/*", "examples/*", "src/buf/*"] +omit = ["*_pb2.py", "tests/*", "examples/*", "src/buf/*"] [tool.mypy] exclude = [ '^src/buf', '^tests/examples', - # mypy 1.10.0 reports false positives for these two files: - # src/dispatch/sdk/v1/function_pb2_grpc.py:74: error: Module has no attribute "experimental" [attr-defined] - # src/dispatch/sdk/v1/dispatch_pb2_grpc.py:80: error: Module has no attribute "experimental" [attr-defined] - '^src/dispatch/sdk/v1/function_pb2_grpc.py', - '^src/dispatch/sdk/v1/dispatch_pb2_grpc.py', ] [tool.pytest.ini_options] diff --git a/src/buf/validate/expression_pb2_grpc.py b/src/buf/validate/expression_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/buf/validate/expression_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/buf/validate/priv/private_pb2_grpc.py b/src/buf/validate/priv/private_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/buf/validate/priv/private_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/buf/validate/validate_pb2_grpc.py b/src/buf/validate/validate_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/buf/validate/validate_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/function.py b/src/dispatch/function.py index dc8a54aa..54925813 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -333,7 +333,7 @@ def _register(self, name: str, wrapped_func: PrimitiveFunction): def batch(self) -> Batch: """Returns a Batch instance that can be used to build a set of calls to dispatch.""" - return self.client.batch() + return Batch(self.client) _registries: Dict[str, Registry] = {} diff --git a/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py b/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/python/v1/pickled_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/call_pb2_grpc.py b/src/dispatch/sdk/v1/call_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/call_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/dispatch_pb2_grpc.py b/src/dispatch/sdk/v1/dispatch_pb2_grpc.py deleted file mode 100644 index 793cfbd3..00000000 --- a/src/dispatch/sdk/v1/dispatch_pb2_grpc.py +++ /dev/null @@ -1,94 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from dispatch.sdk.v1 import dispatch_pb2 as dispatch_dot_sdk_dot_v1_dot_dispatch__pb2 - - -class DispatchServiceStub(object): - """DispatchService is a service allowing the trigger of programmable endpoints - from a dispatch SDK. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.Dispatch = channel.unary_unary( - "/dispatch.sdk.v1.DispatchService/Dispatch", - request_serializer=dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchRequest.SerializeToString, - response_deserializer=dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchResponse.FromString, - ) - - -class DispatchServiceServicer(object): - """DispatchService is a service allowing the trigger of programmable endpoints - from a dispatch SDK. - """ - - def Dispatch(self, request, context): - """Dispatch submits a list of asynchronous function calls to the service. - - The method does not wait for executions to complete before returning, - it only ensures that the creation was persisted, and returns unique - identifiers to represent the executions. - - The request contains a list of executions to be triggered; the method is - atomic, either all executions are recorded, or none and an error is - returned to explain the reason for the failure. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_DispatchServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - "Dispatch": grpc.unary_unary_rpc_method_handler( - servicer.Dispatch, - request_deserializer=dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchRequest.FromString, - response_serializer=dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "dispatch.sdk.v1.DispatchService", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class DispatchService(object): - """DispatchService is a service allowing the trigger of programmable endpoints - from a dispatch SDK. - """ - - @staticmethod - def Dispatch( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/dispatch.sdk.v1.DispatchService/Dispatch", - dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchRequest.SerializeToString, - dispatch_dot_sdk_dot_v1_dot_dispatch__pb2.DispatchResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) diff --git a/src/dispatch/sdk/v1/error_pb2_grpc.py b/src/dispatch/sdk/v1/error_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/error_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/exit_pb2_grpc.py b/src/dispatch/sdk/v1/exit_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/exit_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/function_pb2_grpc.py b/src/dispatch/sdk/v1/function_pb2_grpc.py deleted file mode 100644 index 82193b36..00000000 --- a/src/dispatch/sdk/v1/function_pb2_grpc.py +++ /dev/null @@ -1,88 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -from dispatch.sdk.v1 import function_pb2 as dispatch_dot_sdk_dot_v1_dot_function__pb2 - - -class FunctionServiceStub(object): - """The FunctionService service is used to interface with programmable endpoints - exposing remote functions. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.Run = channel.unary_unary( - "/dispatch.sdk.v1.FunctionService/Run", - request_serializer=dispatch_dot_sdk_dot_v1_dot_function__pb2.RunRequest.SerializeToString, - response_deserializer=dispatch_dot_sdk_dot_v1_dot_function__pb2.RunResponse.FromString, - ) - - -class FunctionServiceServicer(object): - """The FunctionService service is used to interface with programmable endpoints - exposing remote functions. - """ - - def Run(self, request, context): - """Run runs the function identified by the request, and returns a response - that either contains a result when the function completed, or a poll - directive and the associated coroutine state if the function was suspended. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_FunctionServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - "Run": grpc.unary_unary_rpc_method_handler( - servicer.Run, - request_deserializer=dispatch_dot_sdk_dot_v1_dot_function__pb2.RunRequest.FromString, - response_serializer=dispatch_dot_sdk_dot_v1_dot_function__pb2.RunResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - "dispatch.sdk.v1.FunctionService", rpc_method_handlers - ) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. -class FunctionService(object): - """The FunctionService service is used to interface with programmable endpoints - exposing remote functions. - """ - - @staticmethod - def Run( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/dispatch.sdk.v1.FunctionService/Run", - dispatch_dot_sdk_dot_v1_dot_function__pb2.RunRequest.SerializeToString, - dispatch_dot_sdk_dot_v1_dot_function__pb2.RunResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) diff --git a/src/dispatch/sdk/v1/poll_pb2_grpc.py b/src/dispatch/sdk/v1/poll_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/poll_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/sdk/v1/status_pb2_grpc.py b/src/dispatch/sdk/v1/status_pb2_grpc.py deleted file mode 100644 index 8a939394..00000000 --- a/src/dispatch/sdk/v1/status_pb2_grpc.py +++ /dev/null @@ -1,3 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test.py similarity index 99% rename from src/dispatch/test/__init__.py rename to src/dispatch/test.py index 7250c96d..2ad018d8 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test.py @@ -40,14 +40,7 @@ STATUS_TLS_ERROR, ) -from .client import EndpointClient -from .server import DispatchServer -from .service import DispatchService - __all__ = [ - "EndpointClient", - "DispatchServer", - "DispatchService", "function", "method", "main", diff --git a/src/dispatch/test/client.py b/src/dispatch/test/client.py deleted file mode 100644 index 6ff3ba88..00000000 --- a/src/dispatch/test/client.py +++ /dev/null @@ -1,155 +0,0 @@ -from datetime import datetime -from typing import Optional - -import grpc - -from dispatch.sdk.v1 import function_pb2 as function_pb -from dispatch.sdk.v1 import function_pb2_grpc as function_grpc -from dispatch.signature import ( - CaseInsensitiveDict, - Ed25519PrivateKey, - Request, - sign_request, -) -from dispatch.test.http import HttpClient - - -class EndpointClient: - """Test client for a Dispatch programmable endpoint. - - Note that this is different from dispatch.Client, which is a client - for the Dispatch API. The EndpointClient is a client similar to the one - that Dispatch itself would use to interact with an endpoint that provides - functions. - """ - - def __init__( - self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None - ): - """Initialize the client. - - Args: - http_client: Client to use to make HTTP requests. - signing_key: Optional Ed25519 private key to use to sign requests. - """ - channel = _HttpGrpcChannel(http_client, signing_key=signing_key) - self._stub = function_grpc.FunctionServiceStub(channel) - - def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse: - """Send a run request to an endpoint and return its response. - - Args: - request: A FunctionService Run request. - - Returns: - RunResponse: the response from the endpoint. - """ - return self._stub.Run(request) - - -class _HttpGrpcChannel(grpc.Channel): - def __init__( - self, http_client: HttpClient, signing_key: Optional[Ed25519PrivateKey] = None - ): - self.http_client = http_client - self.signing_key = signing_key - - def subscribe(self, callback, try_to_connect=False): - raise NotImplementedError() - - def unsubscribe(self, callback): - raise NotImplementedError() - - def unary_unary(self, method, request_serializer=None, response_deserializer=None): - return _UnaryUnaryMultiCallable( - self.http_client, - method, - request_serializer, - response_deserializer, - self.signing_key, - ) - - def unary_stream(self, method, request_serializer=None, response_deserializer=None): - raise NotImplementedError() - - def stream_unary(self, method, request_serializer=None, response_deserializer=None): - raise NotImplementedError() - - def stream_stream( - self, method, request_serializer=None, response_deserializer=None - ): - raise NotImplementedError() - - def close(self): - raise NotImplementedError() - - def __enter__(self): - raise NotImplementedError() - - def __exit__(self, exc_type, exc_val, exc_tb): - raise NotImplementedError() - - -class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): - def __init__( - self, - client, - method, - request_serializer, - response_deserializer, - signing_key: Optional[Ed25519PrivateKey] = None, - ): - self.client = client - self.method = method - self.request_serializer = request_serializer - self.response_deserializer = response_deserializer - self.signing_key = signing_key - - def __call__( - self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, - ): - url = self.client.url_for(self.method) # note: method==path in gRPC parlance - - request = Request( - method="POST", - url=url, - body=self.request_serializer(request), - headers=CaseInsensitiveDict({"Content-Type": "application/grpc+proto"}), - ) - - if self.signing_key is not None: - sign_request(request, self.signing_key, datetime.now()) - - response = self.client.post( - request.url, body=request.body, headers=request.headers - ) - response.raise_for_status() - return self.response_deserializer(response.body) - - def with_call( - self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, - ): - raise NotImplementedError() - - def future( - self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, - ): - raise NotImplementedError() diff --git a/src/dispatch/test/fastapi.py b/src/dispatch/test/fastapi.py deleted file mode 100644 index 381b1800..00000000 --- a/src/dispatch/test/fastapi.py +++ /dev/null @@ -1,10 +0,0 @@ -from fastapi import FastAPI -from fastapi.testclient import TestClient - -import dispatch.test.httpx -from dispatch.test.client import HttpClient - - -def http_client(app: FastAPI) -> HttpClient: - """Build a client for a FastAPI app.""" - return dispatch.test.httpx.Client(TestClient(app)) diff --git a/src/dispatch/test/flask.py b/src/dispatch/test/flask.py deleted file mode 100644 index e8cc3cbe..00000000 --- a/src/dispatch/test/flask.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Mapping - -import werkzeug.test -from flask import Flask - -from dispatch.test.http import HttpClient, HttpResponse - - -def http_client(app: Flask) -> HttpClient: - """Build a client for a Flask app.""" - return Client(app.test_client()) - - -class Client(HttpClient): - def __init__(self, client: werkzeug.test.Client): - self.client = client - - def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: - response = self.client.get(url, headers=headers.items()) - return Response(response) - - def post( - self, url: str, body: bytes, headers: Mapping[str, str] = {} - ) -> HttpResponse: - response = self.client.post(url, data=body, headers=headers.items()) - return Response(response) - - def url_for(self, path: str) -> str: - return "http://localhost" + path - - -class Response(HttpResponse): - def __init__(self, response): - self.response = response - - @property - def status_code(self): - return self.response.status_code - - @property - def body(self): - return self.response.data - - def raise_for_status(self): - if self.response.status_code // 100 != 2: - raise RuntimeError(f"HTTP status code {self.response.status_code}") diff --git a/src/dispatch/test/http.py b/src/dispatch/test/http.py deleted file mode 100644 index cf7ba9fa..00000000 --- a/src/dispatch/test/http.py +++ /dev/null @@ -1,34 +0,0 @@ -from dataclasses import dataclass -from typing import Mapping, Protocol - -import aiohttp - -from dispatch.function import Client as DefaultClient - - -@dataclass -class HttpResponse(Protocol): - status_code: int - body: bytes - - def raise_for_status(self): - """Raise an exception on non-2xx responses.""" - ... - - -class HttpClient(Protocol): - """Protocol for HTTP clients.""" - - def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: - """Make a GET request.""" - ... - - def post( - self, url: str, body: bytes, headers: Mapping[str, str] = {} - ) -> HttpResponse: - """Make a POST request.""" - ... - - def url_for(self, path: str) -> str: - """Get the fully-qualified URL for a path.""" - ... diff --git a/src/dispatch/test/httpx.py b/src/dispatch/test/httpx.py deleted file mode 100644 index 9d9f7c52..00000000 --- a/src/dispatch/test/httpx.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Mapping - -import httpx - -from dispatch.test.http import HttpClient, HttpResponse - - -class Client(HttpClient): - def __init__(self, client: httpx.Client): - self.client = client - - def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: - response = self.client.get(url, headers=headers) - return Response(response) - - def post( - self, url: str, body: bytes, headers: Mapping[str, str] = {} - ) -> HttpResponse: - response = self.client.post(url, content=body, headers=headers) - return Response(response) - - def url_for(self, path: str) -> str: - return str(httpx.URL(self.client.base_url).join(path)) - - -class Response(HttpResponse): - def __init__(self, response: httpx.Response): - self.response = response - - @property - def status_code(self): - return self.response.status_code - - @property - def body(self): - return self.response.content - - def raise_for_status(self): - self.response.raise_for_status() diff --git a/src/dispatch/test/server.py b/src/dispatch/test/server.py deleted file mode 100644 index a2d022b8..00000000 --- a/src/dispatch/test/server.py +++ /dev/null @@ -1,61 +0,0 @@ -import concurrent.futures -import sys - -import grpc - -from dispatch.sdk.v1 import dispatch_pb2_grpc as dispatch_grpc - - -class DispatchServer: - """Test server for a Dispatch service. This is useful for testing - a mock version of Dispatch locally (e.g. see - dispatch.test.DispatchService). - - Args: - service: Dispatch service to serve. - hostname: Hostname to bind to. - port: Port to bind to, or 0 to bind to any available port. - """ - - def __init__( - self, - service: dispatch_grpc.DispatchServiceServicer, - hostname: str = "127.0.0.1", - port: int = 0, - ): - self._thread_pool = concurrent.futures.thread.ThreadPoolExecutor() - self._server = grpc.server(self._thread_pool) - - self._hostname = hostname - self._port = self._server.add_insecure_port(f"{hostname}:{port}") - - dispatch_grpc.add_DispatchServiceServicer_to_server(service, self._server) - - @property - def url(self): - """Returns the URL of the server.""" - return f"http://{self._hostname}:{self._port}" - - def start(self): - """Start the server.""" - self._server.start() - - def wait(self): - """Block until the server terminates.""" - self._server.wait_for_termination() - - def stop(self): - """Stop the server.""" - self._server.stop(0) - self._server.wait_for_termination() - if sys.version_info >= (3, 9): - self._thread_pool.shutdown(wait=True, cancel_futures=True) - else: - self._thread_pool.shutdown(wait=True) - - def __enter__(self): - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py deleted file mode 100644 index ac23738b..00000000 --- a/src/dispatch/test/service.py +++ /dev/null @@ -1,362 +0,0 @@ -import enum -import logging -import os -import threading -import time -from collections import OrderedDict -from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Tuple - -import grpc -from google.protobuf import any_pb2 as any_pb -from typing_extensions import TypeAlias - -import dispatch.sdk.v1.call_pb2 as call_pb -import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb -import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc -import dispatch.sdk.v1.function_pb2 as function_pb -import dispatch.sdk.v1.poll_pb2 as poll_pb -from dispatch.id import DispatchID -from dispatch.proto import CallResult, Error, Status -from dispatch.test import EndpointClient - -_default_retry_on_status = { - Status.THROTTLED, - Status.TIMEOUT, - Status.TEMPORARY_ERROR, - Status.DNS_ERROR, - Status.TCP_ERROR, - Status.TLS_ERROR, - Status.HTTP_ERROR, -} - - -logger = logging.getLogger(__name__) - - -RoundTrip: TypeAlias = Tuple[function_pb.RunRequest, function_pb.RunResponse] -"""A request to a Dispatch endpoint, and the response that was received.""" - - -class CallType(enum.Enum): - """Type of function call.""" - - CALL = 0 - RESUME = 1 - RETRY = 2 - - -class DispatchService(dispatch_grpc.DispatchServiceServicer): - """Test instance of Dispatch that provides the bare minimum - functionality required to test functions locally.""" - - def __init__( - self, - endpoint_client: EndpointClient, - api_key: Optional[str] = None, - retry_on_status: Optional[Set[Status]] = None, - collect_roundtrips: bool = False, - ): - """Initialize the Dispatch service. - - Args: - endpoint_client: Client to use to interact with the local Dispatch - endpoint (that provides the functions). - api_key: Expected API key on requests to the service. If omitted, the - value of the DISPATCH_API_KEY environment variable is used instead. - retry_on_status: Set of status codes to enable retries for. - collect_roundtrips: Enable collection of request/response round-trips - to the configured endpoint. - """ - super().__init__() - - self.endpoint_client = endpoint_client - - if api_key is None: - api_key = os.getenv("DISPATCH_API_KEY") - self.api_key = api_key - - if retry_on_status is None: - retry_on_status = _default_retry_on_status - self.retry_on_status = retry_on_status - - self._next_dispatch_id = 1 - - self.queue: List[Tuple[DispatchID, function_pb.RunRequest, CallType]] = [] - - self.pollers: Dict[DispatchID, Poller] = {} - self.parents: Dict[DispatchID, Poller] = {} - - self.roundtrips: OrderedDict[DispatchID, List[RoundTrip]] = OrderedDict() - self.collect_roundtrips = collect_roundtrips - - self._thread: Optional[threading.Thread] = None - self._stop_event = threading.Event() - self._work_signal = threading.Condition() - - def Dispatch(self, request: dispatch_pb.DispatchRequest, context): - """RPC handler for Dispatch requests. Requests are only queued for - processing here.""" - self._validate_authentication(context) - - resp = dispatch_pb.DispatchResponse() - - with self._work_signal: - for call in request.calls: - dispatch_id = self._make_dispatch_id() - logger.debug("enqueueing call to function: %s", call.function) - resp.dispatch_ids.append(dispatch_id) - run_request = function_pb.RunRequest( - function=call.function, - input=call.input, - dispatch_id=dispatch_id, - root_dispatch_id=dispatch_id, - ) - self.queue.append((dispatch_id, run_request, CallType.CALL)) - - self._work_signal.notify() - - return resp - - def _validate_authentication(self, context: grpc.ServicerContext): - expected = f"Bearer {self.api_key}" - for key, value in context.invocation_metadata(): - if key == "authorization": - if value == expected: - return - logger.warning( - "a client attempted to dispatch a function call with an incorrect API key. Is the client's DISPATCH_API_KEY correct?" - ) - context.abort( - grpc.StatusCode.UNAUTHENTICATED, - f"Invalid authorization header. Expected '{expected}', got {value!r}", - ) - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Missing authorization header") - - def _make_dispatch_id(self) -> DispatchID: - dispatch_id = self._next_dispatch_id - self._next_dispatch_id += 1 - return "{:032x}".format(dispatch_id) - - def dispatch_calls(self): - """Synchronously dispatch pending function calls to the - configured endpoint.""" - _next_queue: List[Tuple[DispatchID, function_pb.RunRequest, CallType]] = [] - while self.queue: - dispatch_id, request, call_type = self.queue.pop(0) - - if call_type == CallType.CALL: - logger.info("calling function %s", request.function) - elif call_type == CallType.RESUME: - logger.info("resuming function %s", request.function) - elif call_type == CallType.RETRY: - logger.info("retrying function %s", request.function) - - try: - response = self.endpoint_client.run(request) - except: - logger.warning("call to function %s failed", request.function) - self.queue.extend(_next_queue) - self.queue.append((dispatch_id, request, CallType.RETRY)) - raise - - if self.collect_roundtrips: - try: - roundtrips = self.roundtrips[dispatch_id] - except KeyError: - roundtrips = [] - - roundtrips.append((request, response)) - self.roundtrips[dispatch_id] = roundtrips - - status = Status(response.status) - if status == Status.OK: - logger.info("call to function %s succeeded", request.function) - else: - exc = None - if response.HasField("exit"): - if response.exit.HasField("result"): - result = response.exit.result - if result.HasField("error"): - exc = Error._from_proto(result.error).to_exception() - - if exc is not None: - logger.warning( - "call to function %s failed (%s => %s: %s)", - request.function, - status, - exc.__class__.__name__, - str(exc), - ) - else: - logger.warning( - "call to function %s failed (%s)", - request.function, - status, - ) - - if status in self.retry_on_status: - _next_queue.append((dispatch_id, request, CallType.RETRY)) - - elif response.HasField("poll"): - assert not response.HasField("exit") - - logger.info("suspending function %s", request.function) - - logger.debug("registering poller %s", dispatch_id) - - assert dispatch_id not in self.pollers - poller = Poller( - id=dispatch_id, - parent_id=request.parent_dispatch_id, - root_id=request.root_dispatch_id, - function=request.function, - typed_coroutine_state=response.poll.typed_coroutine_state, - waiting={}, - results={}, - ) - self.pollers[dispatch_id] = poller - - for call in response.poll.calls: - child_dispatch_id = self._make_dispatch_id() - child_request = function_pb.RunRequest( - function=call.function, - input=call.input, - dispatch_id=child_dispatch_id, - parent_dispatch_id=request.dispatch_id, - root_dispatch_id=request.root_dispatch_id, - ) - - _next_queue.append( - (child_dispatch_id, child_request, CallType.CALL) - ) - self.parents[child_dispatch_id] = poller - poller.waiting[child_dispatch_id] = call - - else: - assert response.HasField("exit") - - if response.exit.HasField("tail_call"): - tail_call = response.exit.tail_call - logger.debug( - "enqueueing tail call for %s", - tail_call.function, - ) - tail_call_request = function_pb.RunRequest( - function=tail_call.function, - input=tail_call.input, - dispatch_id=request.dispatch_id, - parent_dispatch_id=request.parent_dispatch_id, - root_dispatch_id=request.root_dispatch_id, - ) - _next_queue.append((dispatch_id, tail_call_request, CallType.CALL)) - - elif dispatch_id in self.parents: - result = response.exit.result - poller = self.parents[dispatch_id] - logger.debug( - "notifying poller %s of call result %s", poller.id, dispatch_id - ) - - call = poller.waiting[dispatch_id] - result.correlation_id = call.correlation_id - poller.results[dispatch_id] = result - del self.parents[dispatch_id] - del poller.waiting[dispatch_id] - - logger.debug( - "poller %s has %d waiting and %d ready results", - poller.id, - len(poller.waiting), - len(poller.results), - ) - - if not poller.waiting: - logger.debug( - "poller %s is now ready; enqueueing delivery of %d call result(s)", - poller.id, - len(poller.results), - ) - poll_results_request = function_pb.RunRequest( - dispatch_id=poller.id, - parent_dispatch_id=poller.parent_id, - root_dispatch_id=poller.root_id, - function=poller.function, - poll_result=poll_pb.PollResult( - typed_coroutine_state=poller.typed_coroutine_state, - results=poller.results.values(), - ), - ) - del self.pollers[poller.id] - _next_queue.append( - (poller.id, poll_results_request, CallType.RESUME) - ) - - self.queue = _next_queue - - def start(self): - """Start starts a background thread to continuously dispatch calls to the - configured endpoint.""" - if self._thread is not None: - raise RuntimeError("service has already been started") - - self._stop_event.clear() - self._thread = threading.Thread(target=self._dispatch_continuously) - self._thread.start() - - def stop(self): - """Stop stops the background thread that's dispatching calls to - the configured endpoint.""" - self._stop_event.set() - with self._work_signal: - self._work_signal.notify() - if self._thread is not None: - self._thread.join() - self._thread = None - - def _dispatch_continuously(self): - while True: - with self._work_signal: - if not self.queue and not self._stop_event.is_set(): - self._work_signal.wait() - - if self._stop_event.is_set(): - break - - try: - self.dispatch_calls() - except Exception as e: - logger.exception(e) - - # Introduce an artificial delay before continuing with - # follow-up work (retries, dispatching nested calls). - # This serves two purposes. Firstly, this is just a mock - # Dispatch server providing the bare minimum of functionality. - # Since there's no adaptive concurrency control, and no backoff - # between call attempts, the mock server may busy-loop without - # some sort of delay. Secondly, a bit of latency mimics the - # latency you would see in a production system and makes the - # log output easier to parse. - time.sleep(0.15) - - def __enter__(self): - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - - -@dataclass -class Poller: - id: DispatchID - parent_id: DispatchID - root_id: DispatchID - - function: str - - typed_coroutine_state: any_pb.Any - # TODO: support max_wait/min_results/max_results - - waiting: Dict[DispatchID, call_pb.Call] - results: Dict[DispatchID, call_pb.CallResult] diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index be4964aa..554a032a 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -37,8 +37,6 @@ public_key_from_pem, ) from dispatch.status import Status -from dispatch.test import EndpointClient -from dispatch.test.fastapi import http_client class TestFastAPI(dispatch.test.TestCase): @@ -76,27 +74,3 @@ def dispatch_test_stop(self): loop = self.runner.get_loop() loop.call_soon_threadsafe(self.event.set) - -def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str): - return Dispatch( - app, - registry=Registry( - name=__name__, - endpoint=endpoint, - client=Client( - api_key="0000000000000000", - api_url="http://127.0.0.1:10000", - ), - ), - ) - - -def create_endpoint_client( - app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None -): - return EndpointClient(http_client(app), signing_key) - - -def response_output(resp: function_pb.RunResponse) -> Any: - return any_unpickle(resp.exit.result.output) - From 8394ee5b88aded099b43348bd3072216b0f79387 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jun 2024 12:07:25 -0700 Subject: [PATCH 32/34] fix formatting Signed-off-by: Achille Roussel --- tests/test_fastapi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 554a032a..2af40d33 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -73,4 +73,3 @@ def dispatch_test_run(self): def dispatch_test_stop(self): loop = self.runner.get_loop() loop.call_soon_threadsafe(self.event.set) - From d448489ad8b9fb6c108700740279dd0419898794 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jun 2024 12:18:47 -0700 Subject: [PATCH 33/34] add documentation Signed-off-by: Achille Roussel --- src/dispatch/__init__.py | 17 +++++++++++++ src/dispatch/http.py | 1 + src/dispatch/test.py | 55 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index ea0090a2..95f95ca2 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -65,6 +65,23 @@ def function(func): async def main(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T: + """Entrypoint of dispatch applications. This function creates a new + Dispatch server and runs the provided coroutine in the server's event loop. + + Programs typically don't use this function directly, unless they manage + their own event loop. Most of the time, the `run` function is a more + convenient way to run a dispatch application. + + Args: + coro: The coroutine to run as the entrypoint, the function returns + when the coroutine returns. + + addr: The address to bind the server to. If not provided, the server + will bind to the address specified by the `DISPATCH_ENDPOINT_ADDR` + + Returns: + The value returned by the coroutine. + """ address = addr or str(os.environ.get("DISPATCH_ENDPOINT_ADDR")) or "localhost:8000" parsed_url = urlsplit("//" + address) diff --git a/src/dispatch/http.py b/src/dispatch/http.py index 55778bdc..591642be 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -79,6 +79,7 @@ def function(self, func): return self.registry.function(func) def batch(self) -> Batch: + """Create a new batch.""" return self.registry.batch() async def run( diff --git a/src/dispatch/test.py b/src/dispatch/test.py index 2ad018d8..b3768d14 100644 --- a/src/dispatch/test.py +++ b/src/dispatch/test.py @@ -273,6 +273,19 @@ def session(self) -> aiohttp.ClientSession: async def main(coro: Coroutine[Any, Any, None]) -> None: + """Entrypoint for dispatch function tests, which creates a local Dispatch + server and runs the provided coroutine in the event loop of the server. + + This is a low-level primitive that most test programs wouldn't use directly, + and instead would use one of the `function` or `method` decorators. + + Args: + coro: The coroutine to run as the entrypoint, the function returns + when the coroutine returns. + + Returns: + The value returned by the coroutine. + """ reg = default_registry() api = Service() app = Dispatch(reg) @@ -297,10 +310,48 @@ async def main(coro: Coroutine[Any, Any, None]) -> None: def run(coro: Coroutine[Any, Any, None]) -> None: + """Runs the provided coroutine in the test server's event loop. This + function is a convenience wrapper around the `main` function that runs the + coroutine in the event loop of the test server. + + Programs typically don't use this function directly, unless they manage + their own event loop. Most of the time, the `run` function is a more + convenient way to run a dispatch application. + + Args: + coro: The coroutine to run as the entrypoint, the function returns + when the coroutine returns. + + Returns: + The value returned by the coroutine. + """ return asyncio.run(main(coro)) def function(fn: Callable[[], Coroutine[Any, Any, None]]) -> Callable[[], None]: + """This decorator is used to write tests that execute in a local Dispatch + server. + + The decorated function would typically be a coroutine that implements the + test and returns when the test is done, for example: + + ```python + import dispatch + import dispatch.test + + @dispatch.function + def greet(name: str) -> str: + return f"Hello {name}!" + + @dispatch.test.function + async def test_greet(): + assert await greet("World") == "Hello World!" + ``` + + The test runs dispatch functions with the full dispatch capability, + including retrying temporary errors, etc... + """ + @wraps(fn) def wrapper(): return run(fn()) @@ -309,6 +360,10 @@ def wrapper(): def method(fn: Callable[[T], Coroutine[Any, Any, None]]) -> Callable[[T], None]: + """This decorator is similar to the `function` decorator but is intended to + apply to methods of a class (with a `self` value as first argument). + """ + @wraps(fn) def wrapper(self: T): return run(fn(self)) From e340938dfcaeb4bca36cd405d0579669a26234b0 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 18 Jun 2024 16:24:36 -0700 Subject: [PATCH 34/34] support both asyncio and blocking modes with different abstractions Signed-off-by: Achille Roussel --- src/dispatch/__init__.py | 2 +- .../{asyncio.py => asyncio/__init__.py} | 0 src/dispatch/asyncio/fastapi.py | 108 ++++++++++++++++++ src/dispatch/experimental/lambda_handler.py | 4 +- src/dispatch/fastapi.py | 95 +++------------ src/dispatch/flask.py | 8 +- src/dispatch/function.py | 54 +++++++-- src/dispatch/http.py | 51 +++++++-- src/dispatch/test.py | 7 +- tests/test_fastapi.py | 11 +- tests/test_http.py | 2 +- 11 files changed, 220 insertions(+), 122 deletions(-) rename src/dispatch/{asyncio.py => asyncio/__init__.py} (100%) create mode 100644 src/dispatch/asyncio/fastapi.py diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 95f95ca2..08ad5fe2 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -12,11 +12,11 @@ import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race +from dispatch.function import AsyncFunction as Function from dispatch.function import ( Batch, Client, ClientError, - Function, Registry, Reset, default_registry, diff --git a/src/dispatch/asyncio.py b/src/dispatch/asyncio/__init__.py similarity index 100% rename from src/dispatch/asyncio.py rename to src/dispatch/asyncio/__init__.py diff --git a/src/dispatch/asyncio/fastapi.py b/src/dispatch/asyncio/fastapi.py new file mode 100644 index 00000000..d7cf9f01 --- /dev/null +++ b/src/dispatch/asyncio/fastapi.py @@ -0,0 +1,108 @@ +"""Integration of Dispatch functions with FastAPI for handlers using asyncio. + +Example: + + import fastapi + from dispatch.asyncio.fastapi import Dispatch + + app = fastapi.FastAPI() + dispatch = Dispatch(app) + + @dispatch.function + def my_function(): + return "Hello World!" + + @app.get("/") + async def read_root(): + await my_function.dispatch() +""" + +import logging +from typing import Optional, Union + +import fastapi +import fastapi.responses + +from dispatch.function import Registry +from dispatch.http import ( + AsyncFunctionService, + FunctionServiceError, + validate_content_length, +) +from dispatch.signature import Ed25519PublicKey, parse_verification_key + +logger = logging.getLogger(__name__) + + +class Dispatch(AsyncFunctionService): + """A Dispatch instance, powered by FastAPI.""" + + def __init__( + self, + app: fastapi.FastAPI, + registry: Optional[Registry] = None, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + ): + """Initialize a Dispatch endpoint, and integrate it into a FastAPI app. + + It mounts a sub-app that implements the Dispatch gRPC interface. + + Args: + app: The FastAPI app to configure. + + registry: A registry of functions to expose. If omitted, the default + registry is used. + + verification_key: Key to use when verifying signed requests. Uses + the value of the DISPATCH_VERIFICATION_KEY environment variable + if omitted. The environment variable is expected to carry an + Ed25519 public key in base64 or PEM format. + If not set, request signature verification is disabled (a warning + will be logged by the constructor). + + Raises: + ValueError: If any of the required arguments are missing. + """ + if not app: + raise ValueError( + "missing FastAPI app as first argument of the Dispatch constructor" + ) + super().__init__(registry, verification_key) + function_service = fastapi.FastAPI() + + @function_service.exception_handler(FunctionServiceError) + async def on_error(request: fastapi.Request, exc: FunctionServiceError): + # https://connectrpc.com/docs/protocol/#error-end-stream + return fastapi.responses.JSONResponse( + status_code=exc.status, + content={"code": exc.code, "message": exc.message}, + ) + + @function_service.post( + # The endpoint for execution is hardcoded at the moment. If the service + # gains more endpoints, this should be turned into a dynamic dispatch + # like the official gRPC server does. + "/Run", + ) + async def run(request: fastapi.Request): + valid, reason = validate_content_length( + int(request.headers.get("content-length", 0)) + ) + if not valid: + raise FunctionServiceError(400, "invalid_argument", reason) + + # Raw request body bytes are only available through the underlying + # starlette Request object's body method, which returns an awaitable, + # forcing execute() to be async. + data: bytes = await request.body() + + content = await self.run( + str(request.url), + request.method, + request.headers, + await request.body(), + ) + + return fastapi.Response(content=content, media_type="application/proto") + + app.mount("/dispatch.sdk.v1.FunctionService", function_service) diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 8990c6a1..01a0968d 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -27,7 +27,7 @@ def handler(event, context): from awslambdaric.lambda_context import LambdaContext from dispatch.function import Registry -from dispatch.http import FunctionService +from dispatch.http import BlockingFunctionService from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.status import Status @@ -35,7 +35,7 @@ def handler(event, context): logger = logging.getLogger(__name__) -class Dispatch(FunctionService): +class Dispatch(BlockingFunctionService): def __init__( self, registry: Optional[Registry] = None, diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 3abf7b1a..7bf75f1f 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -15,90 +15,29 @@ def my_function(): @app.get("/") def read_root(): my_function.dispatch() - """ +""" -import logging -from typing import Optional, Union +from typing import Any, Callable, Coroutine, TypeVar, overload -import fastapi -import fastapi.responses +from typing_extensions import ParamSpec -from dispatch.function import Registry -from dispatch.http import FunctionService, FunctionServiceError, validate_content_length -from dispatch.signature import Ed25519PublicKey, parse_verification_key +from dispatch.asyncio.fastapi import Dispatch as AsyncDispatch +from dispatch.function import BlockingFunction -logger = logging.getLogger(__name__) +__all__ = ["Dispatch", "AsyncDispatch"] +P = ParamSpec("P") +T = TypeVar("T") -class Dispatch(FunctionService): - """A Dispatch instance, powered by FastAPI.""" - def __init__( - self, - app: fastapi.FastAPI, - registry: Optional[Registry] = None, - verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, - ): - """Initialize a Dispatch endpoint, and integrate it into a FastAPI app. +class Dispatch(AsyncDispatch): + @overload # type: ignore + def function(self, func: Callable[P, T]) -> BlockingFunction[P, T]: ... - It mounts a sub-app that implements the Dispatch gRPC interface. + @overload # type: ignore + def function( + self, func: Callable[P, Coroutine[Any, Any, T]] + ) -> BlockingFunction[P, T]: ... - Args: - app: The FastAPI app to configure. - - registry: A registry of functions to expose. If omitted, the default - registry is used. - - verification_key: Key to use when verifying signed requests. Uses - the value of the DISPATCH_VERIFICATION_KEY environment variable - if omitted. The environment variable is expected to carry an - Ed25519 public key in base64 or PEM format. - If not set, request signature verification is disabled (a warning - will be logged by the constructor). - - Raises: - ValueError: If any of the required arguments are missing. - """ - if not app: - raise ValueError( - "missing FastAPI app as first argument of the Dispatch constructor" - ) - super().__init__(registry, verification_key) - function_service = fastapi.FastAPI() - - @function_service.exception_handler(FunctionServiceError) - async def on_error(request: fastapi.Request, exc: FunctionServiceError): - # https://connectrpc.com/docs/protocol/#error-end-stream - return fastapi.responses.JSONResponse( - status_code=exc.status, - content={"code": exc.code, "message": exc.message}, - ) - - @function_service.post( - # The endpoint for execution is hardcoded at the moment. If the service - # gains more endpoints, this should be turned into a dynamic dispatch - # like the official gRPC server does. - "/Run", - ) - async def run(request: fastapi.Request): - valid, reason = validate_content_length( - int(request.headers.get("content-length", 0)) - ) - if not valid: - raise FunctionServiceError(400, "invalid_argument", reason) - - # Raw request body bytes are only available through the underlying - # starlette Request object's body method, which returns an awaitable, - # forcing execute() to be async. - data: bytes = await request.body() - - content = await self.run( - str(request.url), - request.method, - request.headers, - await request.body(), - ) - - return fastapi.Response(content=content, media_type="application/proto") - - app.mount("/dispatch.sdk.v1.FunctionService", function_service) + def function(self, func): + return BlockingFunction(super().function(func)) diff --git a/src/dispatch/flask.py b/src/dispatch/flask.py index ffd6c923..0724b899 100644 --- a/src/dispatch/flask.py +++ b/src/dispatch/flask.py @@ -24,13 +24,17 @@ def read_root(): from flask import Flask, make_response, request from dispatch.function import Registry -from dispatch.http import FunctionService, FunctionServiceError, validate_content_length +from dispatch.http import ( + BlockingFunctionService, + FunctionServiceError, + validate_content_length, +) from dispatch.signature import Ed25519PublicKey, parse_verification_key logger = logging.getLogger(__name__) -class Dispatch(FunctionService): +class Dispatch(BlockingFunctionService): """A Dispatch instance, powered by Flask.""" def __init__( diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 54925813..a685a1d0 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -18,6 +18,7 @@ Optional, Tuple, TypeVar, + Union, overload, ) from urllib.parse import urlparse @@ -111,7 +112,7 @@ def _build_primitive_call( ) -class Function(PrimitiveFunction, Generic[P, T]): +class AsyncFunction(PrimitiveFunction, Generic[P, T]): """Callable wrapper around a function meant to be used throughout the Dispatch Python SDK. """ @@ -157,7 +158,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: else: return self._call_dispatch(*args, **kwargs) - def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: + async def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: """Dispatch an asynchronous call to the function without waiting for a result. @@ -171,7 +172,7 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: Returns: DispatchID: ID of the dispatched call. """ - return asyncio.run(self._primitive_dispatch(Arguments(args, kwargs))) + return await self._primitive_dispatch(Arguments(args, kwargs)) def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call: """Create a Call for this function with the provided input. Useful to @@ -187,11 +188,38 @@ def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call: return self._build_primitive_call(Arguments(args, kwargs)) +class BlockingFunction(Generic[P, T]): + """BlockingFunction is like Function but exposes a blocking API instead of + functions that use asyncio. + + Applications typically don't create instances of BlockingFunction directly, + and instead use decorators from packages that provide integrations with + Python frameworks. + """ + + def __init__(self, func: AsyncFunction[P, T]): + self._func = func + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return asyncio.run(self._func(*args, **kwargs)) + + def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: + return asyncio.run(self._func.dispatch(*args, **kwargs)) + + def build_call(self, *args: P.args, **kwargs: P.kwargs) -> Call: + return self._func.build_call(*args, **kwargs) + + class Reset(TailCall): """The current coroutine is aborted and scheduling reset to be replaced with the call embedded in this exception.""" - def __init__(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs): + def __init__( + self, + func: Union[AsyncFunction[P, T], BlockingFunction[P, T]], + *args: P.args, + **kwargs: P.kwargs, + ): super().__init__(call=func.build_call(*args, **kwargs)) @@ -267,10 +295,12 @@ def endpoint(self, value: str): self._endpoint = value @overload - def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... + def function( + self, func: Callable[P, Coroutine[Any, Any, T]] + ) -> AsyncFunction[P, T]: ... @overload - def function(self, func: Callable[P, T]) -> Function[P, T]: ... + def function(self, func: Callable[P, T]) -> AsyncFunction[P, T]: ... def function(self, func): """Decorator that registers functions.""" @@ -283,7 +313,9 @@ def function(self, func): logger.info("registering coroutine: %s", name) return self._register_coroutine(name, func) - def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]: + def _register_function( + self, name: str, func: Callable[P, T] + ) -> AsyncFunction[P, T]: func = durable(func) @wraps(func) @@ -296,7 +328,7 @@ async def asyncio_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def _register_coroutine( self, name: str, func: Callable[P, Coroutine[Any, Any, T]] - ) -> Function[P, T]: + ) -> AsyncFunction[P, T]: logger.info("registering coroutine: %s", name) func = durable(func) @@ -307,7 +339,7 @@ async def primitive_func(input: Input) -> Output: primitive_func.__qualname__ = f"{name}_primitive" durable_primitive_func = durable(primitive_func) - wrapped_func = Function[P, T]( + wrapped_func = AsyncFunction[P, T]( self, name, durable_primitive_func, @@ -555,7 +587,9 @@ def __init__(self, client: Client): self.client = client self.calls = [] - def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch: + def add( + self, func: AsyncFunction[P, T], *args: P.args, **kwargs: P.kwargs + ) -> Batch: """Add a call to the specified function to the batch.""" return self.add_call(func.build_call(*args, **kwargs)) diff --git a/src/dispatch/http.py b/src/dispatch/http.py index 591642be..e56cb265 100644 --- a/src/dispatch/http.py +++ b/src/dispatch/http.py @@ -20,7 +20,14 @@ from http_message_signatures import InvalidSignature from typing_extensions import ParamSpec, TypeAlias -from dispatch.function import Batch, Function, Registry, _calls, default_registry +from dispatch.function import ( + AsyncFunction, + Batch, + BlockingFunction, + Registry, + _calls, + default_registry, +) from dispatch.proto import CallResult, Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -41,7 +48,7 @@ T = TypeVar("T") -class FunctionService: +class BaseFunctionService: """FunctionService is an abstract class intended to be inherited by objects that integrate dispatch with other server application frameworks. @@ -68,16 +75,6 @@ def registry(self) -> Registry: def verification_key(self) -> Optional[Ed25519PublicKey]: return self._verification_key - @overload - def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... - - @overload - def function(self, func: Callable[P, T]) -> Function[P, T]: ... - - def function(self, func): - """Decorator that registers functions.""" - return self.registry.function(func) - def batch(self) -> Batch: """Create a new batch.""" return self.registry.batch() @@ -95,6 +92,36 @@ async def run( ) +class AsyncFunctionService(BaseFunctionService): + @overload + def function( + self, func: Callable[P, Coroutine[Any, Any, T]] + ) -> AsyncFunction[P, T]: ... + + @overload + def function(self, func: Callable[P, T]) -> AsyncFunction[P, T]: ... + + def function(self, func): + return self.registry.function(func) + + +class BlockingFunctionService(BaseFunctionService): + """BlockingFunctionService is a variant of FunctionService which decorates + dispatch functions with a synchronous API instead of using asyncio. + """ + + @overload + def function(self, func: Callable[P, T]) -> BlockingFunction[P, T]: ... + + @overload + def function( + self, func: Callable[P, Coroutine[Any, Any, T]] + ) -> BlockingFunction[P, T]: ... + + def function(self, func): + return BlockingFunction(self.registry.function(func)) + + class FunctionServiceError(Exception): __slots__ = ("status", "code", "message") diff --git a/src/dispatch/test.py b/src/dispatch/test.py index b3768d14..7c8291b8 100644 --- a/src/dispatch/test.py +++ b/src/dispatch/test.py @@ -21,7 +21,7 @@ default_registry, set_default_registry, ) -from dispatch.http import Dispatch, FunctionService +from dispatch.http import Dispatch from dispatch.http import Server as BaseServer from dispatch.sdk.v1.call_pb2 import Call, CallResult from dispatch.sdk.v1.dispatch_pb2 import DispatchRequest, DispatchResponse @@ -290,19 +290,14 @@ async def main(coro: Coroutine[Any, Any, None]) -> None: api = Service() app = Dispatch(reg) try: - print("Starting bakend") async with Server(api) as backend: - print("Starting server") async with Server(app) as server: # Here we break through the abstraction layers a bit, it's not # ideal but it works for now. reg.client.api_url.value = backend.url reg.endpoint = server.url - print("BACKEND:", backend.url) - print("SERVER:", server.url) await coro finally: - print("DONE!") await api.close() # TODO: let's figure out how to get rid of this global registry # state at some point, which forces tests to be run sequentially. diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 2af40d33..97e9766f 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,7 +1,6 @@ import asyncio import socket import sys -from typing import Any, Optional import fastapi import google.protobuf.any_pb2 @@ -19,15 +18,7 @@ from dispatch.asyncio import Runner from dispatch.experimental.durable.registry import clear_functions from dispatch.fastapi import Dispatch -from dispatch.function import ( - Arguments, - Client, - Error, - Function, - Input, - Output, - Registry, -) +from dispatch.function import Arguments, Client, Error, Input, Output, Registry from dispatch.proto import _any_unpickle as any_unpickle from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import function_pb2 as function_pb diff --git a/tests/test_http.py b/tests/test_http.py index cba4742b..cdfc277b 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -6,7 +6,7 @@ import dispatch.test from dispatch.asyncio import Runner from dispatch.function import Registry -from dispatch.http import Dispatch, FunctionService, Server +from dispatch.http import Dispatch, Server class TestHTTP(dispatch.test.TestCase):