Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch submit #127

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from fastapi.testclient import TestClient

from dispatch.client import Client
from dispatch.function import Client
from dispatch.test import DispatchServer, DispatchService, EndpointClient


Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

import dispatch.integrations
from dispatch.client import DEFAULT_API_URL, Client
from dispatch.coroutine import call, gather
from dispatch.function import DEFAULT_API_URL, Client
from dispatch.id import DispatchID
from dispatch.proto import Call, Error, Input, Output
from dispatch.status import Status
Expand Down
118 changes: 0 additions & 118 deletions src/dispatch/client.py

This file was deleted.

14 changes: 9 additions & 5 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def read_root():
import fastapi.responses
from http_message_signatures import InvalidSignature

from dispatch.client import Client
from dispatch.function import Registry
from dispatch.function import Batch, Client, Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
Expand All @@ -47,7 +46,7 @@ def read_root():
class Dispatch(Registry):
"""A Dispatch programmable endpoint, powered by FastAPI."""

__slots__ = ()
__slots__ = ("client",)

def __init__(
self,
Expand Down Expand Up @@ -116,12 +115,17 @@ def __init__(
"request verification is disabled because DISPATCH_VERIFICATION_KEY is not set"
)

client = Client(api_key=api_key, api_url=api_url)
super().__init__(endpoint, client)
self.client = Client(api_key=api_key, api_url=api_url)
super().__init__(endpoint, self.client)

function_service = _new_app(self, verification_key)
app.mount("/dispatch.sdk.v1.FunctionService", function_service)

def batch(self) -> Batch:
"""Returns a Batch instance that can be used to build
a set of calls to dispatch."""
return self.client.batch()


def parse_verification_key(
verification_key: Ed25519PublicKey | str | bytes | None,
Expand Down
155 changes: 154 additions & 1 deletion src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect
import logging
import os
from functools import wraps
from types import CoroutineType
from typing import (
Expand All @@ -10,14 +11,19 @@
Coroutine,
Dict,
Generic,
Iterable,
ParamSpec,
TypeAlias,
TypeVar,
overload,
)
from urllib.parse import urlparse

import grpc

import dispatch.coroutine
from dispatch.client import Client
import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb
import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc
from dispatch.experimental.durable import durable
from dispatch.id import DispatchID
from dispatch.proto import Arguments, Call, Error, Input, Output
Expand All @@ -33,6 +39,9 @@
"""


DEFAULT_API_URL = "https://api.dispatch.run"


class PrimitiveFunction:
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func")

Expand Down Expand Up @@ -234,3 +243,147 @@ def set_client(self, client: Client):
self._client = client
for fn in self._functions.values():
fn._client = client


class Client:
"""Client for the Dispatch API."""

__slots__ = ("api_url", "api_key", "_stub", "api_key_from")

def __init__(self, api_key: None | str = None, api_url: None | str = None):
"""Create a new Dispatch client.

Args:
api_key: Dispatch API key to use for authentication. Uses the value of
the DISPATCH_API_KEY environment variable by default.

api_url: The URL of the Dispatch API to use. Uses the value of the
DISPATCH_API_URL environment variable if set, otherwise
defaults to the public Dispatch API (DEFAULT_API_URL).

Raises:
ValueError: if the API key is missing.
"""

if api_key:
self.api_key_from = "api_key"
else:
self.api_key_from = "DISPATCH_API_KEY"
api_key = os.environ.get("DISPATCH_API_KEY")
if not api_key:
raise ValueError(
"missing API key: set it with the DISPATCH_API_KEY environment variable"
)

if not api_url:
api_url = os.environ.get("DISPATCH_API_URL", DEFAULT_API_URL)
if not api_url:
raise ValueError(
"missing API URL: set it with the DISPATCH_API_URL environment variable"
)

logger.debug("initializing client for Dispatch API at URL %s", api_url)
self.api_url = api_url
self.api_key = api_key
self._init_stub()

def __getstate__(self):
return {"api_url": self.api_url, "api_key": self.api_key}

def __setstate__(self, state):
self.api_url = state["api_url"]
self.api_key = state["api_key"]
self._init_stub()

def _init_stub(self):
result = urlparse(self.api_url)
match result.scheme:
case "http":
creds = grpc.local_channel_credentials()
case "https":
creds = grpc.ssl_channel_credentials()
case _:
raise ValueError(f"Invalid API scheme: '{result.scheme}'")

call_creds = grpc.access_token_call_credentials(self.api_key)
creds = grpc.composite_channel_credentials(creds, call_creds)
channel = grpc.secure_channel(result.netloc, creds)

self._stub = dispatch_grpc.DispatchServiceStub(channel)

def batch(self) -> Batch:
"""Returns a Batch instance that can be used to build
a set of calls to dispatch."""
return Batch(self)

def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]:
"""Dispatch function calls.

Args:
calls: Calls to dispatch.

Returns:
Identifiers for the function calls, in the same order as the inputs.
"""
calls_proto = [c._as_proto() for c in calls]
logger.debug("dispatching %d function call(s)", len(calls_proto))
req = dispatch_pb.DispatchRequest(calls=calls_proto)

try:
resp = self._stub.Dispatch(req)
except grpc.RpcError as e:
status_code = e.code()
match status_code:
case grpc.StatusCode.UNAUTHENTICATED:
raise PermissionError(
f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)"
) from e
raise

dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids]
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"dispatched %d function call(s): %s",
len(calls_proto),
", ".join(dispatch_ids),
)
return dispatch_ids


class Batch:
"""A batch of calls to dispatch."""

__slots__ = ("client", "calls")

def __init__(self, client: Client):
self.client = client
self.calls: list[Call] = []

def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch:
"""Add a call to the specified function to the batch."""
return self.add_call(func.build_call(*args, correlation_id=None, **kwargs))

def add_call(self, call: Call) -> Batch:
"""Add a Call to the batch."""
self.calls.append(call)
return self

def dispatch(self) -> list[DispatchID]:
"""Dispatch dispatches the calls asynchronously.

The batch is reset when the calls are dispatched successfully.

Returns:
Identifiers for the function calls, in the same order they
were added.
"""
if not self.calls:
return []

dispatch_ids = self.client.dispatch(self.calls)
self.reset()
return dispatch_ids

def reset(self):
"""Reset the batch."""
self.calls = []
3 changes: 1 addition & 2 deletions tests/dispatch/test_function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pickle
import unittest

from dispatch.client import Client
from dispatch.function import Registry
from dispatch.function import Client, Registry


class TestFunction(unittest.TestCase):
Expand Down
Loading