From 41a43d9ce427424e3ae3cc834872b4e771d20f41 Mon Sep 17 00:00:00 2001 From: Sicheng Pan Date: Wed, 22 Jan 2025 11:30:48 -0800 Subject: [PATCH 1/2] [PERF] Refactor frontend middleware --- chromadb/server/fastapi/__init__.py | 76 +++++++++++++++-------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 4d0e05ca716..174d43aba20 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -64,6 +64,8 @@ UpdateEmbedding, ) from starlette.datastructures import Headers +from starlette.types import Receive, Send, Scope +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint import logging from chromadb.telemetry.product.events import ServerStartEvent @@ -104,45 +106,45 @@ def use_route_names_as_operation_ids(app: _FastAPI) -> None: route.operation_id = route.name -async def add_trace_id_to_response_middleware( - request: Request, call_next: Callable[[Request], Any] -) -> Response: - trace_id = trace.get_current_span().get_span_context().trace_id - response = await call_next(request) - response.headers["Chroma-Trace-Id"] = format(trace_id, "x") - return response +class AddTraceIdInResponseHeaderMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + trace_id = trace.get_current_span().get_span_context().trace_id + response = await call_next(request) + response.headers["Chroma-Trace-Id"] = format(trace_id, "x") + return response -async def catch_exceptions_middleware( - request: Request, call_next: Callable[[Request], Any] -) -> Response: - try: - return await call_next(request) - except ChromaError as e: - return fastapi_json_response(e) - except ValueError as e: - return ORJSONResponse( - content={"error": "InvalidArgumentError", "message": str(e)}, - status_code=400, - ) - except TypeError as e: - return ORJSONResponse( - content={"error": "InvalidArgumentError", "message": str(e)}, - status_code=400, - ) - except Exception as e: - logger.exception(e) - return ORJSONResponse(content={"error": repr(e)}, status_code=500) +class CatchExceptionsMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + try: + return await call_next(request) + except ChromaError as e: + return fastapi_json_response(e) + except ValueError as e: + return ORJSONResponse( + content={"error": "InvalidArgumentError", "message": str(e)}, + status_code=400, + ) + except TypeError as e: + return ORJSONResponse( + content={"error": "InvalidArgumentError", "message": str(e)}, + status_code=400, + ) + except Exception as e: + logger.exception(e) + return ORJSONResponse(content={"error": repr(e)}, status_code=500) -async def check_http_version_middleware( - request: Request, call_next: Callable[[Request], Any] -) -> Response: - http_version = request.scope.get("http_version") - if http_version not in ["1.1", "2"]: - raise InvalidHTTPVersion(f"HTTP version {http_version} is not supported") - return await call_next(request) +class CheckHttpVersionMiddleware: + def __init__(self, app: fastapi.FastAPI): + self._app = app + async def __call__(self, scope: Scope, receive: Receive, send: Send): + http_version = scope.get("http_version") + if http_version not in ["1.1", "2"]: + raise InvalidHTTPVersion(f"HTTP version {http_version} is not supported") + await self._app(scope, receive, send) + D = TypeVar("D", bound=BaseModel, contravariant=True) @@ -201,9 +203,9 @@ def __init__(self, settings: Settings): self._quota_enforcer = self._system.require(QuotaEnforcer) self._system.start() - self._app.middleware("http")(check_http_version_middleware) - self._app.middleware("http")(catch_exceptions_middleware) - self._app.middleware("http")(add_trace_id_to_response_middleware) + self._app.add_middleware(CheckHttpVersionMiddleware) + self._app.add_middleware(CatchExceptionsMiddleware) + self._app.add_middleware(AddTraceIdInResponseHeaderMiddleware) self._app.add_middleware( CORSMiddleware, allow_headers=["*"], From 0fd726041c6b8877e7a50bc8dc30a0bad41e639f Mon Sep 17 00:00:00 2001 From: Sicheng Pan Date: Wed, 22 Jan 2025 11:35:56 -0800 Subject: [PATCH 2/2] Fix check http version middleware --- chromadb/server/fastapi/__init__.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 174d43aba20..e4dc899ed61 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -135,16 +135,13 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): return ORJSONResponse(content={"error": repr(e)}, status_code=500) -class CheckHttpVersionMiddleware: - def __init__(self, app: fastapi.FastAPI): - self._app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send): - http_version = scope.get("http_version") +class CheckHttpVersionMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + http_version = request.scope.get("http_version") if http_version not in ["1.1", "2"]: raise InvalidHTTPVersion(f"HTTP version {http_version} is not supported") - await self._app(scope, receive, send) - + return await call_next(request) + D = TypeVar("D", bound=BaseModel, contravariant=True)