diff --git a/.github/workflows/ci_code.yml b/.github/workflows/ci_code.yml
index 0711344a9..991d7b24c 100644
--- a/.github/workflows/ci_code.yml
+++ b/.github/workflows/ci_code.yml
@@ -78,6 +78,11 @@ jobs:
           make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/default.yaml
           make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/sql.yaml

+      - name: Rest Server Testing
+        run: |
+          make rest_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/default.yaml
+
+
       - name: Usecase Testing
         run: |
           make usecase_testing SUPERDUPER_CONFIG=test/configs/default.yaml
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 232c03557..b01b3f653 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -59,6 +59,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Qdrant vector search support
 - Add placeholder for web app link in Application
 - Add support for remote artifacts
+- Add basic rest server

 #### Bug Fixes
diff --git a/Makefile b/Makefile
index 5b4c30e5f..45c5d4baa 100644
--- a/Makefile
+++ b/Makefile
@@ -98,6 +98,9 @@ fix-and-check: ## Lint the code before testing

 ##@ CI Testing Functions

+rest_testing: ## Execute rest unit tests
+	SUPERDUPER_CONFIG=$(SUPERDUPER_CONFIG) pytest $(PYTEST_ARGUMENTS) ./test/rest
+
 unit_testing: ## Execute unit testing
 	SUPERDUPER_CONFIG=$(SUPERDUPER_CONFIG) pytest $(PYTEST_ARGUMENTS) ./test/unittest
diff --git a/pyproject.toml b/pyproject.toml
index b21b9afcc..6f37910d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,6 +55,7 @@ dependencies = [
     "python-magic",
     "apscheduler",
     "bson",
+    "python-multipart>=0.0.9",
     "qdrant-client>=1.10.0,<2"
 ]
diff --git a/superduper/rest/app.py b/superduper/rest/app.py
new file mode 100644
index 000000000..7f1c39922
--- /dev/null
+++ b/superduper/rest/app.py
@@ -0,0 +1,22 @@
+import yaml
+
+from superduper import CFG, logging
+from superduper.rest.base import SuperDuperApp
+from superduper.rest.build import build_rest_app
+
+assert isinstance(
+    CFG.cluster.rest.uri, str
+), "cluster.rest.uri should be set with a valid uri"
+port = int(CFG.cluster.rest.uri.split(':')[-1])
+
+if CFG.cluster.rest.config is not None:
+    try:
+        with open(CFG.cluster.rest.config) as f:
+            CONFIG = yaml.safe_load(f)
+    except FileNotFoundError:
+        logging.warn("cluster.rest.config should be set with a valid path")
+        CONFIG = {}
+
+app = SuperDuperApp('rest', port=port)
+
+build_rest_app(app)
diff --git a/superduper/rest/base.py b/superduper/rest/base.py
new file mode 100644
index 000000000..b2830014d
--- /dev/null
+++ b/superduper/rest/base.py
@@ -0,0 +1,248 @@
+import sys
+import threading
+import time
+import typing as t
+from functools import cached_property
+from traceback import format_exc
+
+import uvicorn
+from fastapi import APIRouter, Depends, FastAPI, Request
+from fastapi.exceptions import HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from prettytable import PrettyTable
+from starlette.middleware.base import BaseHTTPMiddleware
+
+from superduper import logging
+from superduper.base.build import build_datalayer
+from superduper.base.config import Config
+from superduper.base.datalayer import Datalayer
+
+# --------------- Create exception handler middleware -----------------
+
+
+class ExceptionHandlerMiddleware(BaseHTTPMiddleware):
+    """Middleware to handle exceptions and log them."""
+
+    async def dispatch(self, request: Request, call_next):
+        """Dispatch the request and handle exceptions.
+
+        :param request: request to dispatch
+        :param call_next: next call to make
+        """
+        try:
+            return await call_next(request)
+        except Exception as e:
+            host = getattr(getattr(request, "client", None), "host", None)
+            port = getattr(getattr(request, "client", None), "port", None)
+            url = (
+                f"{request.url.path}?{request.query_params}"
+                if request.query_params
+                else request.url.path
+            )
+            exception_type, exception_value, _ = sys.exc_info()
+            exception_traceback = format_exc()
+            exception_name = getattr(exception_type, "__name__", None)
+            msg = (
+                f'{host}:{port} - "{request.method} {url}" '
+                f'500 Internal Server Error <{exception_name}: {exception_value}>'
+            )
+            logging.exception(msg, e=e)
+            return JSONResponse(
+                status_code=500,
+                content={
+                    'error': exception_name,
+                    'messages': msg,
+                    'traceback': exception_traceback,
+                },
+            )
+
+
+class SuperDuperApp:
+    """A wrapper class for creating a FastAPI application.
+
+    The class provides a simple interface for creating a FastAPI application
+    with custom endpoints.
+
+    :param service: name of the service
+    :param port: port to run the service on
+    :param db: datalayer instance
+    :param prefix: prefix (root path) under which the routes are mounted
+    """
+
+    def __init__(
+        self, service='rest', port=8000, db: t.Optional[Datalayer] = None, prefix: str = ''
+    ):
+        if prefix and not prefix.startswith('/'):
+            prefix = f'/{prefix}'
+
+        self.service = service
+
+        self.port = port
+
+        self.app_host = '0.0.0.0'
+        self._app = FastAPI(
+            root_path=prefix,
+        )
+
+        self.router = APIRouter()
+
+        self._app.add_middleware(ExceptionHandlerMiddleware)
+        self._app.add_middleware(
+            CORSMiddleware,
+            allow_origins=["*"],  # You can specify a list of allowed origins here
+            allow_credentials=True,
+            allow_methods=[
+                "GET",
+                "POST",
+                "PUT",
+                "DELETE",
+            ],  # You can adjust these as per your needs
+            allow_headers=["*"],  # You can specify allowed headers here
+        )
+        self._db = db
+
+    @cached_property
+    def app(self):
+        """Return the application instance."""
+        self._app.include_router(self.router)
+        return self._app
+
+    def raise_error(self, msg: str, code: int):
+        """Raise an error with the given message and code.
+
+        :param msg: message to raise
+        :param code: code to raise
+        """
+        raise HTTPException(code, detail=msg)
+
+    @cached_property
+    def db(self) -> Datalayer:
+        """Return the database instance from the app state."""
+        return self._app.state.pool
+
+    def add(self, *args, method='post', **kwargs):
+        """Register an endpoint with this method.
+
+        :param method: method to use
+        """
+
+        def decorator(function):
+            self.router.add_api_route(
+                *args, **kwargs, endpoint=function, methods=[method]
+            )
+            return function
+
+        return decorator
+
+    def add_default_endpoints(self):
+        """Add default endpoints to the application.
+
+        - /health: Health check endpoint
+        """
+
+        @self.router.get('/health')
+        def health():
+            return {'status': 200}
+
+    def print_routes(self):
+        """Print the routes of the application."""
+        table = PrettyTable()
+
+        # Define the table headers
+        table.field_names = ["Path", "Methods", "Function"]
+
+        # Add rows to the table
+        for route in self._app.routes:
+            table.add_row([route.path, ", ".join(route.methods), route.name])
+
+        logging.info(f"Routes for '{self.service}' app: \n{table}")
+
+    def pre_start(self, cfg: t.Union[Config, None] = None):
+        """Pre-start the application.
+
+        :param cfg: Configurations to use
+        """
+        self.add_default_endpoints()
+        self.startup(cfg=cfg)
+        assert self.app
+
+    def run(self):
+        """Run the application."""
+        uvicorn.run(
+            self._app,
+            host=self.app_host,
+            port=self.port,
+        )
+
+    def start(self):
+        """Start the application."""
+        self.pre_start()
+        self.print_routes()
+        self.run()
+
+    def startup(
+        self,
+        cfg: t.Union[Config, None] = None,
+    ):
+        """Startup the application.
+
+        :param cfg: Configurations to use
+        """
+
+        @self._app.on_event('startup')
+        def startup_db_client():
+            sys.path.append('./')
+            if self._db is None:
+                db = build_datalayer(cfg)
+            else:
+                db = self._db
+            self._app.state.pool = db
+            self._db = db
+
+        return
+
+    def shutdown(self, function: t.Union[t.Callable, None] = None):
+        """Shutdown the application.
+
+        :param function: function to run on shutdown
+        """
+
+        @self._app.on_event('shutdown')
+        def shutdown_db_client():
+            try:
+                self._app.state.pool.close()
+            except AttributeError:
+                raise Exception('Could not close the database properly')
+
+
+def database(request: Request) -> Datalayer:
+    """Return the database instance from the app state.
+
+    :param request: request object
+    """
+    return request.app.state.pool
+
+
+def DatalayerDependency():
+    """Dependency for injecting datalayer instance into endpoint implementation."""
+    return Depends(database)
+
+
+class Server(uvicorn.Server):
+    """Custom server class."""
+
+    def install_signal_handlers(self):
+        """Install signal handlers."""
+        pass
+
+    def run_in_thread(self):
+        """Run the server in a separate thread."""
+        self._thread = threading.Thread(target=self.run)
+        self._thread.start()
+
+        while not self.started:
+            time.sleep(1e-3)
+
+    def stop(self):
+        """Stop the server."""
+        self.should_exit = True
+        self._thread.join()
diff --git a/superduper/rest/build.py b/superduper/rest/build.py
new file mode 100644
index 000000000..2ec5efd49
--- /dev/null
+++ b/superduper/rest/build.py
@@ -0,0 +1,134 @@
+import hashlib
+import typing as t
+
+import magic
+from fastapi import File, Response
+
+from superduper import logging
+from superduper.backends.base.query import Query
+from superduper.base.document import Document
+from superduper.components.component import Component
+from superduper.rest.base import SuperDuperApp
+
+from .utils import rewrite_artifacts
+
+
+def build_rest_app(app: SuperDuperApp):
+    """
+    Add the key endpoints to the FastAPI app.
+
+    :param app: SuperDuperApp
+    """
+
+    @app.add('/db/artifact_store/put', method='put')
+    def db_artifact_store_put_bytes(raw: bytes = File(...)):
+        file_id = str(hashlib.sha1(raw).hexdigest())
+        app.db.artifact_store.put_bytes(serialized=raw, file_id=file_id)
+        return {'file_id': file_id}
+
+    @app.add('/db/artifact_store/get', method='get')
+    def db_artifact_store_get_bytes(file_id: str):
+        bytes = app.db.artifact_store.get_bytes(file_id=file_id)
+        media_type = magic.from_buffer(bytes, mime=True)
+        return Response(content=bytes, media_type=media_type)
+
+    @app.add('/db/apply', method='post')
+    def db_apply(info: t.Dict):
+        if '_variables' in info:
+            assert {'_variables', 'identifier'}.issubset(info.keys())
+            variables = info.pop('_variables')
+            for k in variables:
+                assert '<' not in variables[k]
+                assert '>' not in variables[k]
+                assert ' ' not in variables[k]
+
+            identifier = info.pop('identifier')
+            template_name = info.pop('_template_name', None)
+
+            component = Component.from_template(
+                identifier=identifier,
+                template_body=info,
+                template_name=template_name,
+                db=app.db,
+                **variables,
+            )
+            app.db.apply(component)
+            return {'status': 'ok'}
+        component = Document.decode(info).unpack()
+        app.db.apply(component)
+        return {'status': 'ok'}
+
+    @app.add('/db/show', method='get')
+    def db_show(
+        type_id: t.Optional[str] = None,
+        identifier: t.Optional[str] = None,
+        version: t.Optional[int] = None,
+        application: t.Optional[str] = None,
+    ):
+        if application is not None:
+            r = app.db.metadata.get_component('application', application)
+            return r['namespace']
+        else:
+            return app.db.show(
+                type_id=type_id,
+                identifier=identifier,
+                version=version,
+            )
+
+    @app.add('/db/remove', method='post')
+    def db_remove(type_id: str, identifier: str):
+        app.db.remove(type_id=type_id, identifier=identifier, force=True)
+        return {'status': 'ok'}
+
+    @app.add('/db/show_template', method='get')
+    def db_show_template(identifier: str, type_id: str = 'template'):
+        template = app.db.load(type_id=type_id, identifier=identifier)
+        return template.form_template
+
+    @app.add('/db/metadata/show_jobs', method='get')
+    def db_metadata_show_jobs(type_id: str, identifier: t.Optional[str] = None):
+        return [
+            r['job_id']
+            for r in app.db.metadata.show_jobs(
+                type_id=type_id, component_identifier=identifier
+            )
+            if 'job_id' in r
+        ]
+
+    @app.add('/db/execute', method='post')
+    def db_execute(
+        query: t.Dict,
+    ):
+        if '_path' not in query:
+            plugin = app.db.databackend.type.__module__.split('.')[0]
+            query['_path'] = f'{plugin}.query.parse_query'
+
+        q = Document.decode(query, db=app.db).unpack()
+
+        logging.info('processing this query:')
+        logging.info(q)
+
+        result = q.execute()
+
+        if q.type in {'insert', 'delete', 'update'}:
+            return {'_base': [str(x) for x in result[0]]}, []
+
+        logging.warn(str(q))
+
+        if isinstance(result, Document):
+            result = [result]
+
+        result = [rewrite_artifacts(r, db=app.db) for r in result]
+        result = [r.encode() for r in result]
+        blobs_keys = [list(r.pop_blobs().keys()) for r in result]
+        result = list(zip(result, blobs_keys))
+
+        if isinstance(q, Query):
+            for i, r in enumerate(result):
+                r = list(r)
+                if q.primary_id in r[0]:
+                    r[0][q.primary_id] = str(r[0][q.primary_id])
+                result[i] = tuple(r)
+        if 'score' in result[0][0]:
+            result = sorted(result, key=lambda x: -x[0]['score'])
+        return result
diff --git a/superduper/rest/deployed_app.py b/superduper/rest/deployed_app.py
new file mode 100644
index 000000000..648571bfa
--- /dev/null
+++ b/superduper/rest/deployed_app.py
@@ -0,0 +1,4 @@
+from .app import app
+
+app.pre_start()
+app.print_routes()
diff --git a/superduper/rest/utils.py b/superduper/rest/utils.py
new file mode 100644
index 000000000..ae25e9fb6
--- /dev/null
+++ b/superduper/rest/utils.py
@@ -0,0 +1,23 @@
+import inspect
+
+from superduper import Document
+from superduper.components.datatype import Artifact, Encodable
+
+
+def rewrite_artifacts(r, db):
+    """Helper function to rewrite encodables as artifacts."""
+    if isinstance(r, Encodable):
+        kwargs = r.dict()
+        kwargs['datatype'].encodable = 'artifact'
+        blob = r._encode()[0]
+        db.artifact_store.put_bytes(blob, file_id=r.identifier)
+        init_args = inspect.signature(Artifact.__init__).parameters.keys()
+        kwargs = {k: v for k, v in kwargs.items() if k in init_args}
+        return Artifact(**kwargs)
+    if isinstance(r, Document):
+        return Document(rewrite_artifacts(dict(r), db=db))
+    if isinstance(r, dict):
+        return {k: rewrite_artifacts(v, db=db) for k, v in r.items()}
+    if isinstance(r, list):
+        return [rewrite_artifacts(v, db=db) for v in r]
+    return r
diff --git a/test/rest/__init__.py b/test/rest/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/rest/mock_client.py b/test/rest/mock_client.py
new file mode 100644
index 000000000..0ea1c9038
--- /dev/null
+++ b/test/rest/mock_client.py
@@ -0,0 +1,58 @@
+import os
+from urllib.parse import urlencode
+
+from superduper import CFG
+
+HOST = CFG.cluster.rest.uri
+VERBOSE = os.environ.get('SUPERDUPER_VERBOSE', '1')
+
+
+def make_params(params):
+    return '?' + urlencode(params)
+
+
+def insert(client, data):
+    query = {'query': 'coll.insert_many(documents)', 'documents': data}
+    return client.post('/db/execute', json=query)
+
+
+def apply(client, component):
+    return client.post('/db/apply', json=component)
+
+
+def delete(client):
+    return client.post('/db/execute', json={'query': 'coll.delete_many({})'})
+
+
+def remove(client, type_id, identifier):
+    return client.post(f'/db/remove?type_id={type_id}&identifier={identifier}', json={})
+
+
+def setup(client):
+    from superduper.base.build import build_datalayer
+
+    db = build_datalayer()
+    db.cfg.auto_schema = True
+    client.app.state.pool = db
+    data = [
+        {"x": [1, 2, 3, 4, 5], "y": 'test'},
+        {"x": [6, 7, 8, 9, 10], "y": 'test'},
+    ]
+    insert(client, data)
+    return client
+
+
+def teardown(client):
+    delete(client)
+    remove(client, 'datatype', 'image')
+
+
+if __name__ == '__main__':
+    import sys
+
+    from fastapi.testclient import TestClient
+
+    from superduper.rest.deployed_app import app
+
+    # setup/teardown expect a client; run them against the deployed app in-process
+    client = TestClient(app._app)
+
+    if sys.argv[1] == 'setup':
+        setup(client)
+    elif sys.argv[1] == 'teardown':
+        teardown(client)
+    else:
+        raise NotImplementedError
diff --git a/test/rest/test_rest.py b/test/rest/test_rest.py
new file mode 100644
index 000000000..ce107f0a9
--- /dev/null
+++ b/test/rest/test_rest.py
@@ -0,0 +1,132 @@
+import json
+
+import pytest
+from fastapi.testclient import TestClient
+
+from superduper import CFG
+from superduper.base.document import Document
+
+CFG.auto_schema = True
+CFG.cluster.rest.uri = 'localhost:8000'
+from superduper.rest.deployed_app import app
+
+from .mock_client import setup as _setup, teardown
+
+
+@pytest.fixture
+def setup():
+    client = TestClient(app._app)
+    yield _setup(client)
+    teardown(client)
+
+
+def test_health(setup):
+    response = setup.get("/health")
+    assert response.status_code == 200
+
+
+def test_select_data(setup):
+    result = setup.post('/db/execute', json={'query': 'coll.find({}, {"_id": 0})'})
+    result = json.loads(result.content)
+    if 'error' in result:
+        raise Exception(result['messages'] + result['traceback'])
+    print(result)
+    assert len(result) == 2
+
+
+CODE = """
+from superduper import code
+
+@code
+def my_function(x):
+    return x + 1
+"""
+
+
+def test_apply(setup):
+    m = {
+        '_builds': {
+            'function_body': {
+                '_path': 'superduper.base.code.Code',
+                'code': CODE,
+            },
+            'my_function': {
+                '_path': 'superduper.components.model.ObjectModel',
+                'object': '?function_body',
+                'identifier': 'my_function',
+            },
+        },
+        '_base': '?my_function',
+    }
+
+    _ = setup.post(
+        '/db/apply',
+        json=m,
+    )
+
+    models = setup.get('/db/show', params={'type_id': 'model'})
+    models = json.loads(models.content)
+
+    assert models == ['my_function']
+
+
+@pytest.mark.skip
+def test_insert_image(setup):
+    result = setup.put(
+        '/db/artifact_store/put',
+        files={"raw": open("test/material/data/test.png", "rb")},
+    )
+    result = json.loads(result.content)
+
+    file_id = result['file_id']
+
+    query = {
+        '_path': 'superduper.backends.mongodb.query.parse_query',
+        'query': 'coll.insert_one(documents[0])',
+        '_builds': {
+            'image_type': {
+                '_path': 'superduper.ext.pillow.encoder.image_type',
+                'encodable': 'artifact',
+            },
+            'my_artifact': {
+                '_path': 'superduper.components.datatype.LazyArtifact',
+                'blob': f'&:blob:{file_id}',
+                'datatype': "?image_type",
+            },
+        },
+        'documents': [
+            {
+                'img': '?my_artifact',
+            }
+        ],
+    }
+
+    result = setup.post(
+        '/db/execute',
+        json=query,
+    )
+
+    query = {
+        '_path': 'superduper.backends.mongodb.query.parse_query',
+        'query': 'coll.find(documents[0], documents[1])',
+        'documents': [{}, {'_id': 0}],
+    }
+
+    result = setup.post(
+        '/db/execute',
+        json=query,
+    )
+
+    result = json.loads(result.content)
+    from superduper import superduper
+
+    db = superduper()
+
+    result = [Document.decode(r[0], db=db).unpack() for r in result]
+
+    assert len(result) == 3
+
+    image_record = next(r for r in result if 'img' in r)
+
+    from PIL.PngImagePlugin import PngImageFile
+
+    assert isinstance(image_record['img'], PngImageFile)
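
A minimal usage sketch of the new REST endpoints, driven in-process with FastAPI's TestClient and stitched together from test/rest/mock_client.py and test/rest/test_rest.py. It assumes a working data backend behind the active SUPERDUPER_CONFIG; the collection name `coll` and the sample documents are illustrative only and not part of this diff.

import json

from fastapi.testclient import TestClient

from superduper import CFG
from superduper.base.build import build_datalayer

# app.py asserts that cluster.rest.uri is set before the app module is imported
CFG.auto_schema = True
CFG.cluster.rest.uri = 'localhost:8000'

from superduper.rest.deployed_app import app  # noqa: E402

client = TestClient(app._app)

# Inject a datalayer, as mock_client.setup does, so SuperDuperApp.db resolves via app.state.pool
db = build_datalayer()
db.cfg.auto_schema = True
client.app.state.pool = db

# Health endpoint registered by SuperDuperApp.add_default_endpoints()
assert client.get('/health').status_code == 200

# Insert documents through the generic /db/execute endpoint
client.post(
    '/db/execute',
    json={
        'query': 'coll.insert_many(documents)',
        'documents': [{'x': [1, 2, 3], 'y': 'test'}],
    },
)

# Read them back, projecting out '_id' so the payload stays JSON-serialisable
response = client.post('/db/execute', json={'query': 'coll.find({}, {"_id": 0})'})
print(json.loads(response.content))

# List applied models via /db/show
print(json.loads(client.get('/db/show', params={'type_id': 'model'}).content))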