diff --git a/examples/github_stats/test_app.py b/examples/github_stats/test_app.py index 8b5fd916..56c69ccd 100644 --- a/examples/github_stats/test_app.py +++ b/examples/github_stats/test_app.py @@ -8,7 +8,7 @@ from fastapi.testclient import TestClient -from dispatch.client import Client +from dispatch.function import Client from dispatch.test import DispatchServer, DispatchService, EndpointClient diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 2b2f5ff7..a9affbda 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -3,8 +3,8 @@ from __future__ import annotations import dispatch.integrations -from dispatch.client import DEFAULT_API_URL, Client from dispatch.coroutine import call, gather +from dispatch.function import DEFAULT_API_URL, Client from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output from dispatch.status import Status diff --git a/src/dispatch/client.py b/src/dispatch/client.py deleted file mode 100644 index 9f67490b..00000000 --- a/src/dispatch/client.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -import logging -import os -from typing import Iterable -from urllib.parse import urlparse - -import grpc - -import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb -import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc -from dispatch.id import DispatchID -from dispatch.proto import Call - -logger = logging.getLogger(__name__) - - -DEFAULT_API_URL = "https://api.dispatch.run" - - -class Client: - """Client for the Dispatch API.""" - - __slots__ = ("api_url", "api_key", "_stub", "api_key_from") - - def __init__(self, api_key: None | str = None, api_url: None | str = None): - """Create a new Dispatch client. - - Args: - api_key: Dispatch API key to use for authentication. Uses the value of - the DISPATCH_API_KEY environment variable by default. - - api_url: The URL of the Dispatch API to use. Uses the value of the - DISPATCH_API_URL environment variable if set, otherwise - defaults to the public Dispatch API (DEFAULT_API_URL). - - Raises: - ValueError: if the API key is missing. - """ - - if api_key: - self.api_key_from = "api_key" - else: - self.api_key_from = "DISPATCH_API_KEY" - api_key = os.environ.get("DISPATCH_API_KEY") - if not api_key: - raise ValueError( - "missing API key: set it with the DISPATCH_API_KEY environment variable" - ) - - if not api_url: - api_url = os.environ.get("DISPATCH_API_URL", DEFAULT_API_URL) - if not api_url: - raise ValueError( - "missing API URL: set it with the DISPATCH_API_URL environment variable" - ) - - logger.debug("initializing client for Dispatch API at URL %s", api_url) - self.api_url = api_url - self.api_key = api_key - self._init_stub() - - def __getstate__(self): - return {"api_url": self.api_url, "api_key": self.api_key} - - def __setstate__(self, state): - self.api_url = state["api_url"] - self.api_key = state["api_key"] - self._init_stub() - - def _init_stub(self): - result = urlparse(self.api_url) - match result.scheme: - case "http": - creds = grpc.local_channel_credentials() - case "https": - creds = grpc.ssl_channel_credentials() - case _: - raise ValueError(f"Invalid API scheme: '{result.scheme}'") - - call_creds = grpc.access_token_call_credentials(self.api_key) - creds = grpc.composite_channel_credentials(creds, call_creds) - channel = grpc.secure_channel(result.netloc, creds) - - self._stub = dispatch_grpc.DispatchServiceStub(channel) - - def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]: - """Dispatch function calls. - - Args: - calls: Calls to dispatch. - - Returns: - Identifiers for the function calls, in the same order as the inputs. - """ - calls_proto = [c._as_proto() for c in calls] - logger.debug("dispatching %d function call(s)", len(calls_proto)) - req = dispatch_pb.DispatchRequest(calls=calls_proto) - - try: - resp = self._stub.Dispatch(req) - except grpc.RpcError as e: - status_code = e.code() - match status_code: - case grpc.StatusCode.UNAUTHENTICATED: - raise PermissionError( - f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)" - ) from e - raise - - dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids] - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "dispatched %d function call(s): %s", - len(calls_proto), - ", ".join(dispatch_ids), - ) - return dispatch_ids diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 2e29260b..97c4f7ed 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -27,8 +27,7 @@ def read_root(): import fastapi.responses from http_message_signatures import InvalidSignature -from dispatch.client import Client -from dispatch.function import Registry +from dispatch.function import Batch, Client, Registry from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( @@ -47,7 +46,7 @@ def read_root(): class Dispatch(Registry): """A Dispatch programmable endpoint, powered by FastAPI.""" - __slots__ = () + __slots__ = ("client",) def __init__( self, @@ -116,12 +115,17 @@ def __init__( "request verification is disabled because DISPATCH_VERIFICATION_KEY is not set" ) - client = Client(api_key=api_key, api_url=api_url) - super().__init__(endpoint, client) + self.client = Client(api_key=api_key, api_url=api_url) + super().__init__(endpoint, self.client) function_service = _new_app(self, verification_key) app.mount("/dispatch.sdk.v1.FunctionService", function_service) + def batch(self) -> Batch: + """Returns a Batch instance that can be used to build + a set of calls to dispatch.""" + return self.client.batch() + def parse_verification_key( verification_key: Ed25519PublicKey | str | bytes | None, diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 03df30b2..ba9475e0 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -2,6 +2,7 @@ import inspect import logging +import os from functools import wraps from types import CoroutineType from typing import ( @@ -10,14 +11,19 @@ Coroutine, Dict, Generic, + Iterable, ParamSpec, TypeAlias, TypeVar, overload, ) +from urllib.parse import urlparse + +import grpc import dispatch.coroutine -from dispatch.client import Client +import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb +import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc from dispatch.experimental.durable import durable from dispatch.id import DispatchID from dispatch.proto import Arguments, Call, Error, Input, Output @@ -33,6 +39,9 @@ """ +DEFAULT_API_URL = "https://api.dispatch.run" + + class PrimitiveFunction: __slots__ = ("_endpoint", "_client", "_name", "_primitive_func") @@ -234,3 +243,147 @@ def set_client(self, client: Client): self._client = client for fn in self._functions.values(): fn._client = client + + +class Client: + """Client for the Dispatch API.""" + + __slots__ = ("api_url", "api_key", "_stub", "api_key_from") + + def __init__(self, api_key: None | str = None, api_url: None | str = None): + """Create a new Dispatch client. + + Args: + api_key: Dispatch API key to use for authentication. Uses the value of + the DISPATCH_API_KEY environment variable by default. + + api_url: The URL of the Dispatch API to use. Uses the value of the + DISPATCH_API_URL environment variable if set, otherwise + defaults to the public Dispatch API (DEFAULT_API_URL). + + Raises: + ValueError: if the API key is missing. + """ + + if api_key: + self.api_key_from = "api_key" + else: + self.api_key_from = "DISPATCH_API_KEY" + api_key = os.environ.get("DISPATCH_API_KEY") + if not api_key: + raise ValueError( + "missing API key: set it with the DISPATCH_API_KEY environment variable" + ) + + if not api_url: + api_url = os.environ.get("DISPATCH_API_URL", DEFAULT_API_URL) + if not api_url: + raise ValueError( + "missing API URL: set it with the DISPATCH_API_URL environment variable" + ) + + logger.debug("initializing client for Dispatch API at URL %s", api_url) + self.api_url = api_url + self.api_key = api_key + self._init_stub() + + def __getstate__(self): + return {"api_url": self.api_url, "api_key": self.api_key} + + def __setstate__(self, state): + self.api_url = state["api_url"] + self.api_key = state["api_key"] + self._init_stub() + + def _init_stub(self): + result = urlparse(self.api_url) + match result.scheme: + case "http": + creds = grpc.local_channel_credentials() + case "https": + creds = grpc.ssl_channel_credentials() + case _: + raise ValueError(f"Invalid API scheme: '{result.scheme}'") + + call_creds = grpc.access_token_call_credentials(self.api_key) + creds = grpc.composite_channel_credentials(creds, call_creds) + channel = grpc.secure_channel(result.netloc, creds) + + self._stub = dispatch_grpc.DispatchServiceStub(channel) + + def batch(self) -> Batch: + """Returns a Batch instance that can be used to build + a set of calls to dispatch.""" + return Batch(self) + + def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]: + """Dispatch function calls. + + Args: + calls: Calls to dispatch. + + Returns: + Identifiers for the function calls, in the same order as the inputs. + """ + calls_proto = [c._as_proto() for c in calls] + logger.debug("dispatching %d function call(s)", len(calls_proto)) + req = dispatch_pb.DispatchRequest(calls=calls_proto) + + try: + resp = self._stub.Dispatch(req) + except grpc.RpcError as e: + status_code = e.code() + match status_code: + case grpc.StatusCode.UNAUTHENTICATED: + raise PermissionError( + f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)" + ) from e + raise + + dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids] + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "dispatched %d function call(s): %s", + len(calls_proto), + ", ".join(dispatch_ids), + ) + return dispatch_ids + + +class Batch: + """A batch of calls to dispatch.""" + + __slots__ = ("client", "calls") + + def __init__(self, client: Client): + self.client = client + self.calls: list[Call] = [] + + def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch: + """Add a call to the specified function to the batch.""" + return self.add_call(func.build_call(*args, correlation_id=None, **kwargs)) + + def add_call(self, call: Call) -> Batch: + """Add a Call to the batch.""" + self.calls.append(call) + return self + + def dispatch(self) -> list[DispatchID]: + """Dispatch dispatches the calls asynchronously. + + The batch is reset when the calls are dispatched successfully. + + Returns: + Identifiers for the function calls, in the same order they + were added. + """ + if not self.calls: + return [] + + dispatch_ids = self.client.dispatch(self.calls) + self.reset() + return dispatch_ids + + def reset(self): + """Reset the batch.""" + self.calls = [] diff --git a/tests/dispatch/test_function.py b/tests/dispatch/test_function.py index e9ed5954..6f4a93ab 100644 --- a/tests/dispatch/test_function.py +++ b/tests/dispatch/test_function.py @@ -1,8 +1,7 @@ import pickle import unittest -from dispatch.client import Client -from dispatch.function import Registry +from dispatch.function import Client, Registry class TestFunction(unittest.TestCase): diff --git a/tests/test_client.py b/tests/test_client.py index 29a2ddec..09f96b59 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -71,3 +71,28 @@ def test_call_pickle(self): self.assertEqual(dispatch_id, dispatch_ids[0]) self.assertEqual(call.function, "my-function") self.assertEqual(any_unpickle(call.input), 42) + + def test_batch(self): + batch = self.dispatch_client.batch() + + batch.add_call(Call(function="my-function", input=42)).add_call( + Call(function="my-function", input=23) + ).add_call(Call(function="my-function2", input=11)) + + dispatch_ids = batch.dispatch() + self.assertEqual(len(dispatch_ids), 3) + + pending_calls = self.dispatch_service.queue + self.assertEqual(len(pending_calls), 3) + dispatch_id0, call0, _ = pending_calls[0] + dispatch_id1, call1, _ = pending_calls[1] + dispatch_id2, call2, _ = pending_calls[2] + self.assertListEqual(dispatch_ids, [dispatch_id0, dispatch_id1, dispatch_id2]) + self.assertEqual(call0.function, "my-function") + self.assertEqual(any_unpickle(call0.input), 42) + self.assertEqual(call1.function, "my-function") + self.assertEqual(any_unpickle(call1.input), 23) + self.assertEqual(call2.function, "my-function2") + self.assertEqual(any_unpickle(call2.input), 11) + + self.assertEqual(len(batch.dispatch()), 0) # batch was reset