Skip to content

dispatch.test package #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 22, 2024
Merged
51 changes: 27 additions & 24 deletions examples/auto_retry/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from fastapi.testclient import TestClient

import dispatch.sdk.v1.status_pb2 as status_pb

from ... import function_service
from ...test_client import ServerTest
from dispatch import Client
from dispatch.sdk.v1 import status_pb2 as status_pb
from dispatch.test import DispatchServer, DispatchService, EndpointClient


class TestAutoRetry(unittest.TestCase):
Expand All @@ -22,29 +21,33 @@ class TestAutoRetry(unittest.TestCase):
"DISPATCH_API_KEY": "0000000000000000",
},
)
def test_foo(self):
from . import app
def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:

server = ServerTest()
servicer = server.servicer
app.dispatch._client = server.client
app.some_logic._client = server.client
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

http_client = TestClient(app.app, base_url="http://dispatch-service")
app_client = function_service.client(http_client)
http_client = TestClient(app)
response = http_client.get("/")
self.assertEqual(response.status_code, 200)

response = http_client.get("/")
self.assertEqual(response.status_code, 200)
dispatch_service.dispatch_calls()

server.execute(app_client)
# Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6
# calls, including 5 retries.
for i in range(6):
dispatch_service.dispatch_calls()

# Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6
# calls, including 5 retries.
for i in range(6):
server.execute(app_client)
self.assertEqual(len(servicer.responses), 6)
self.assertEqual(len(dispatch_service.roundtrips), 1)
roundtrips = list(dispatch_service.roundtrips.values())[0]
self.assertEqual(len(roundtrips), 6)

statuses = [r["response"].status for r in servicer.responses]
self.assertEqual(
statuses, [status_pb.STATUS_TEMPORARY_ERROR] * 5 + [status_pb.STATUS_OK]
)
statuses = [response.status for request, response in roundtrips]
self.assertEqual(
statuses, [status_pb.STATUS_TEMPORARY_ERROR] * 5 + [status_pb.STATUS_OK]
)
31 changes: 17 additions & 14 deletions examples/getting_started/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from fastapi.testclient import TestClient

from ... import function_service
from ...test_client import ServerTest
from dispatch import Client
from dispatch.test import DispatchServer, DispatchService, EndpointClient


class TestGettingStarted(unittest.TestCase):
Expand All @@ -20,20 +20,23 @@ class TestGettingStarted(unittest.TestCase):
"DISPATCH_API_KEY": "0000000000000000",
},
)
def test_foo(self):
from . import app
def test_app(self):
from .app import app, dispatch

server = ServerTest()
servicer = server.servicer
app.dispatch._client = server.client
app.publish._client = server.client
# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:

http_client = TestClient(app.app, base_url="http://dispatch-service")
app_client = function_service.client(http_client)
# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

response = http_client.get("/")
self.assertEqual(response.status_code, 200)
http_client = TestClient(app)
response = http_client.get("/")
self.assertEqual(response.status_code, 200)

server.execute(app_client)
dispatch_service.dispatch_calls()

self.assertEqual(len(servicer.responses), 1)
self.assertEqual(len(dispatch_service.roundtrips), 1) # one call submitted
dispatch_id, roundtrips = list(dispatch_service.roundtrips.items())[0]
self.assertEqual(len(roundtrips), 1) # one roundtrip for this call
55 changes: 34 additions & 21 deletions examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from fastapi.testclient import TestClient

from ... import function_service
from ...test_client import ServerTest
from dispatch.client import Client
from dispatch.test import DispatchServer, DispatchService, EndpointClient


class TestGithubStats(unittest.TestCase):
Expand All @@ -20,22 +20,35 @@ class TestGithubStats(unittest.TestCase):
"DISPATCH_API_KEY": "0000000000000000",
},
)
def test_foo(self):
from . import app

server = ServerTest()
servicer = server.servicer
app.dispatch._client = server.client
app.get_repo_info._client = server.client
app.get_contributors._client = server.client
app.main._client = server.client

http_client = TestClient(app.app, base_url="http://dispatch-service")
app_client = function_service.client(http_client)

response = http_client.get("/")
self.assertEqual(response.status_code, 200)

server.execute(app_client)

self.assertEqual(len(servicer.responses), 1)
def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:

# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

http_client = TestClient(app)
response = http_client.get("/")
self.assertEqual(response.status_code, 200)

while dispatch_service.queue:
dispatch_service.dispatch_calls()

# Three unique functions were called, with five total round-trips.
# The main function is called initially, and then polls
# twice, for three total round-trips. There's one round-trip
# to get_repo_info and one round-trip to get_contributors.
self.assertEqual(
3, len(dispatch_service.roundtrips)
) # 3 unique functions were called
self.assertEqual(
5,
sum(
len(roundtrips)
for roundtrips in dispatch_service.roundtrips.values()
),
)
2 changes: 1 addition & 1 deletion src/dispatch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _init_stub(self):

self._stub = dispatch_grpc.DispatchServiceStub(channel)

def dispatch(self, calls: Iterable[Call]) -> Iterable[DispatchID]:
def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]:
"""Dispatch function calls.

Args:
Expand Down
19 changes: 10 additions & 9 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Function:
def __init__(
self,
endpoint: str,
client: Client | None,
client: Client,
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable,
Expand Down Expand Up @@ -102,11 +102,6 @@ def dispatch(self, *args: Any, **kwargs: Any) -> DispatchID:
return self._primitive_dispatch(Arguments(args, kwargs))

def _primitive_dispatch(self, input: Any = None) -> DispatchID:
if self._client is None:
raise RuntimeError(
"Dispatch Client has not been configured (api_key not provided)"
)

[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
return dispatch_id

Expand Down Expand Up @@ -151,13 +146,13 @@ class Registry:

__slots__ = ("_functions", "_endpoint", "_client")

def __init__(self, endpoint: str, client: Client | None):
def __init__(self, endpoint: str, client: Client):
"""Initialize a local function registry.

Args:
endpoint: URL of the endpoint that the function is accessible from.
client: Optional client for the Dispatch API. If provided, calls
to local functions can be dispatched directly.
client: Client for the Dispatch API. Used to dispatch calls to
local functions.
"""
self._functions: Dict[str, Function] = {}
self._endpoint = endpoint
Expand Down Expand Up @@ -235,3 +230,9 @@ def _register(
)
self._functions[name] = wrapped_func
return wrapped_func

def set_client(self, client: Client):
"""Set the Client instance used to dispatch calls to local functions."""
self._client = client
for fn in self._functions.values():
fn._client = client
5 changes: 5 additions & 0 deletions src/dispatch/test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .client import EndpointClient
from .server import DispatchServer
from .service import DispatchService

__all__ = ["EndpointClient", "DispatchServer", "DispatchService"]
Loading