Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mock server #126

Merged
merged 11 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ dev = [
"coverage >= 7.4.1",
"requests >= 2.31.0",
"types-requests >= 2.31.0.20240125",
"docopt >= 0.6.2",
"types-docopt >= 0.6.11.4",
"uvicorn >= 0.28.0"
]

docs = [
Expand Down
77 changes: 77 additions & 0 deletions src/dispatch/test/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Mock Dispatch server for use in test environments.

Usage:
dispatch.test <endpoint> [--api-key=<key>] [--hostname=<name>] [--port=<port>] [-v | --verbose]
dispatch.test -h | --help

Options:
--api-key=<key> API key to require when clients connect to the server [default: test].

--hostname=<name> Hostname to listen on [default: 127.0.0.1].
--port=<port> Port to listen on [default: 4450].

-v --verbose Show verbose details in the log.
-h --help Show this help information.
"""

import base64
import logging
import sys

from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from docopt import docopt

from dispatch.test import DispatchServer, DispatchService, EndpointClient


def main():
args = docopt(__doc__)

if args["--help"]:
print(__doc__)
exit(0)

endpoint = args["<endpoint>"]
api_key = args["--api-key"]
hostname = args["--hostname"]
port_str = args["--port"]

try:
port = int(port_str)
except ValueError:
print(f"error: invalid port: {port_str}", file=sys.stderr)
exit(1)

# This private key was generated randomly.
signing_key = Ed25519PrivateKey.from_private_bytes(
b"\x0e\xca\xfb\xc9\xa9Gc'fR\xe4\x97y\xf0\xae\x90\x01\xe8\xd9\x94\xa6\xd4@\xf6\xa7!\x90b\\!z!"
)
verification_key = base64.b64encode(
signing_key.public_key().public_bytes_raw()
).decode()

log_level = logging.DEBUG if args["--verbose"] else logging.INFO
logging.basicConfig(
level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

endpoint_client = EndpointClient.from_url(endpoint, signing_key=signing_key)

with DispatchService(endpoint_client, api_key=api_key) as service:
with DispatchServer(service, hostname=hostname, port=port) as server:
print(f"Spawned a mock Dispatch server on {hostname}:{port} to dispatch")
print(f"function calls to the endpoint at {endpoint}.")
print()
print("The Dispatch SDK can be configured with:")
print()
print(f' export DISPATCH_API_URL="http://{hostname}:{port}"')
print(f' export DISPATCH_API_KEY="{api_key}"')
print(f' export DISPATCH_ENDPOINT_URL="{endpoint}"')
print(f' export DISPATCH_VERIFICATION_KEY="{verification_key}"')
print()

server.wait()


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions src/dispatch/test/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def start(self):
"""Start the server."""
self._server.start()

def wait(self):
"""Block until the server terminates."""
self._server.wait_for_termination()

def stop(self):
"""Stop the server."""
self._server.stop(0)
Expand Down
50 changes: 37 additions & 13 deletions src/dispatch/test/service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
import os
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import TypeAlias

import grpc
import httpx

import dispatch.sdk.v1.call_pb2 as call_pb
import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb
Expand Down Expand Up @@ -93,9 +95,7 @@ def Dispatch(self, request: dispatch_pb.DispatchRequest, context):
with self._work_signal:
for call in request.calls:
dispatch_id = self._make_dispatch_id()
logger.debug(
"enqueueing call to function %s as %s", call.function, dispatch_id
)
logger.debug("enqueueing call to function: %s", call.function)
resp.dispatch_ids.append(dispatch_id)
run_request = function_pb.RunRequest(
function=call.function,
Expand All @@ -113,6 +113,9 @@ def _validate_authentication(self, context: grpc.ServicerContext):
if key == "authorization":
if value == expected:
return
logger.warning(
"a client attempted to dispatch a function call with an incorrect API key. Is the client's DISPATCH_API_KEY correct?"
)
context.abort(
grpc.StatusCode.UNAUTHENTICATED,
f"Invalid authorization header. Expected '{expected}', got {value!r}",
Expand All @@ -131,11 +134,14 @@ def dispatch_calls(self):
while self.queue:
dispatch_id, request = self.queue.pop(0)

logger.debug(
"dispatching call to function %s (%s)", request.function, dispatch_id
)
logger.info("dispatching call to function: %s", request.function)

response = self.endpoint_client.run(request)
try:
response = self.endpoint_client.run(request)
except:
self.queue.extend(_next_queue)
self.queue.append((dispatch_id, request)) # retry
raise

if self.roundtrips is not None:
try:
Expand All @@ -147,9 +153,7 @@ def dispatch_calls(self):
self.roundtrips[dispatch_id] = roundtrips

if Status(response.status) in self.retry_on_status:
logger.debug(
"retrying call to function %s (%s)", request.function, dispatch_id
)
logger.info("retrying call to function: %s", request.function)
_next_queue.append((dispatch_id, request))

elif response.HasField("poll"):
Expand Down Expand Up @@ -182,9 +186,8 @@ def dispatch_calls(self):
if response.exit.HasField("tail_call"):
tail_call = response.exit.tail_call
logger.debug(
"enqueueing tail call to %s (%s)",
"enqueueing tail call to function: %s",
tail_call.function,
dispatch_id,
)
tail_call_request = function_pb.RunRequest(
function=tail_call.function,
Expand Down Expand Up @@ -259,7 +262,28 @@ def _dispatch_continuously(self):
if self._stop_event.is_set():
break

self.dispatch_calls()
ok = False
try:
self.dispatch_calls()
except httpx.HTTPStatusError as e:
if e.response.status_code == 403:
logger.error(
"error dispatching function call to endpoint (403). Is the endpoint's DISPATCH_VERIFICATION_KEY correct?"
)
else:
logger.exception(e)
except httpx.ConnectError as e:
logger.error(
"error connecting to the endpoint. Is it running and accessible from DISPATCH_ENDPOINT_URL?"
)
except Exception as e:
logger.exception(e)
else:
ok = True
if not ok:
# Introduce an artificial delay between errors to
# avoid busy-loops.
time.sleep(1.0)

def __enter__(self):
self.start()
Expand Down