diff --git a/README.md b/README.md index 8591353..8b7d3f8 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Python package to develop applications with the Dispatch platform. - [Running Dispatch Applications](#running-dispatch-applications) - [Writing Transactional Applications with Dispatch](#writing-transactional-applications-with-dispatch) - [Integration with FastAPI](#integration-with-fastapi) + - [Integration with Flask](#integration-with-flask) - [Configuration](#configuration) - [Serialization](#serialization) - [Examples](#examples) @@ -198,6 +199,22 @@ In this example, GET requests on the HTTP server dispatch calls to the `publish` function. The function runs concurrently to the rest of the program, driven by the Dispatch SDK. +### Integration with Flask + +Dispatch can also be integrated with web applications built on [Flask][flask]. + +The API is nearly identical to FastAPI above, instead use: + +```python +from flask import Flask +from dispatch.flask import Dispatch + +app = Flask(__name__) +dispatch = Dispatch(app) +``` + +[flask]: https://flask.palletsprojects.com/en/3.0.x/ + ### Configuration The Dispatch CLI automatically configures the SDK, so manual configuration is diff --git a/src/dispatch/test/flask.py b/src/dispatch/test/flask.py new file mode 100644 index 0000000..e8cc3cb --- /dev/null +++ b/src/dispatch/test/flask.py @@ -0,0 +1,46 @@ +from typing import Mapping + +import werkzeug.test +from flask import Flask + +from dispatch.test.http import HttpClient, HttpResponse + + +def http_client(app: Flask) -> HttpClient: + """Build a client for a Flask app.""" + return Client(app.test_client()) + + +class Client(HttpClient): + def __init__(self, client: werkzeug.test.Client): + self.client = client + + def get(self, url: str, headers: Mapping[str, str] = {}) -> HttpResponse: + response = self.client.get(url, headers=headers.items()) + return Response(response) + + def post( + self, url: str, body: bytes, headers: Mapping[str, str] = {} + ) -> HttpResponse: + response = self.client.post(url, data=body, headers=headers.items()) + return Response(response) + + def url_for(self, path: str) -> str: + return "http://localhost" + path + + +class Response(HttpResponse): + def __init__(self, response): + self.response = response + + @property + def status_code(self): + return self.response.status_code + + @property + def body(self): + return self.response.data + + def raise_for_status(self): + if self.response.status_code // 100 != 2: + raise RuntimeError(f"HTTP status code {self.response.status_code}") diff --git a/tests/test_flask.py b/tests/test_flask.py new file mode 100644 index 0000000..dbf5932 --- /dev/null +++ b/tests/test_flask.py @@ -0,0 +1,143 @@ +import base64 +import os +import pickle +import struct +import unittest +from typing import Any, Optional +from unittest import mock + +import google.protobuf.any_pb2 +import google.protobuf.wrappers_pb2 +from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PrivateKey, + Ed25519PublicKey, +) +from flask import Flask + +import dispatch +from dispatch.experimental.durable.registry import clear_functions +from dispatch.flask import Dispatch +from dispatch.function import Arguments, Error, Function, Input, Output +from dispatch.proto import _any_unpickle as any_unpickle +from dispatch.sdk.v1 import call_pb2 as call_pb +from dispatch.sdk.v1 import function_pb2 as function_pb +from dispatch.signature import ( + parse_verification_key, + private_key_from_pem, + public_key_from_pem, +) +from dispatch.status import Status +from dispatch.test import DispatchServer, DispatchService, EndpointClient +from dispatch.test.flask import http_client + + +def create_dispatch_instance(app: Flask, endpoint: str): + return Dispatch( + app, + endpoint=endpoint, + api_key="0000000000000000", + api_url="http://127.0.0.1:10000", + ) + + +def create_endpoint_client(app: Flask, signing_key: Optional[Ed25519PrivateKey] = None): + return EndpointClient(http_client(app), signing_key) + + +class TestFlask(unittest.TestCase): + def test_flask(self): + app = Flask(__name__) + dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/") + + @dispatch.primitive_function + def my_function(input: Input) -> Output: + return Output.value( + f"You told me: '{input.input}' ({len(input.input)} characters)" + ) + + client = create_endpoint_client(app) + pickled = pickle.dumps("Hello World!") + input_any = google.protobuf.any_pb2.Any() + input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled)) + + req = function_pb.RunRequest( + function=my_function.name, + input=input_any, + ) + + resp = client.run(req) + + self.assertIsInstance(resp, function_pb.RunResponse) + + resp.exit.result.output.Unpack( + output_bytes := google.protobuf.wrappers_pb2.BytesValue() + ) + output = pickle.loads(output_bytes.value) + + self.assertEqual(output, "You told me: 'Hello World!' (12 characters)") + + +signing_key = private_key_from_pem( + """ +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF +-----END PRIVATE KEY----- +""" +) + +verification_key = public_key_from_pem( + """ +-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs= +-----END PUBLIC KEY----- +""" +) + + +class TestFlaskE2E(unittest.TestCase): + def setUp(self): + self.endpoint_app = Flask(__name__) + endpoint_client = create_endpoint_client(self.endpoint_app, signing_key) + + api_key = "0000000000000000" + self.dispatch_service = DispatchService( + endpoint_client, api_key, collect_roundtrips=True + ) + self.dispatch_server = DispatchServer(self.dispatch_service) + self.dispatch_client = dispatch.Client( + api_key, api_url=self.dispatch_server.url + ) + + self.dispatch = Dispatch( + self.endpoint_app, + endpoint="http://function-service", # unused + verification_key=verification_key, + api_key=api_key, + api_url=self.dispatch_server.url, + ) + + self.dispatch_server.start() + + def tearDown(self): + self.dispatch_server.stop() + + def test_simple_end_to_end(self): + # The Flask server. + @self.dispatch.function + def my_function(name: str) -> str: + return f"Hello world: {name}" + + call = my_function.build_call(52) + self.assertEqual(call.function.split(".")[-1], "my_function") + + # The client. + [dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)]) + + # Simulate execution for testing purposes. + self.dispatch_service.dispatch_calls() + + # Validate results. + roundtrips = self.dispatch_service.roundtrips[dispatch_id] + self.assertEqual(len(roundtrips), 1) + _, response = roundtrips[0] + self.assertEqual(any_unpickle(response.exit.result.output), "Hello world: 52")