Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PERF] Refactor frontend middleware #3535

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 36 additions & 37 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
UpdateEmbedding,
)
from starlette.datastructures import Headers
from starlette.types import Receive, Send, Scope
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
import logging

from chromadb.telemetry.product.events import ServerStartEvent
Expand Down Expand Up @@ -104,44 +106,41 @@ def use_route_names_as_operation_ids(app: _FastAPI) -> None:
route.operation_id = route.name


async def add_trace_id_to_response_middleware(
request: Request, call_next: Callable[[Request], Any]
) -> Response:
trace_id = trace.get_current_span().get_span_context().trace_id
response = await call_next(request)
response.headers["Chroma-Trace-Id"] = format(trace_id, "x")
return response
class AddTraceIdInResponseHeaderMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
trace_id = trace.get_current_span().get_span_context().trace_id
response = await call_next(request)
response.headers["Chroma-Trace-Id"] = format(trace_id, "x")
return response


async def catch_exceptions_middleware(
request: Request, call_next: Callable[[Request], Any]
) -> Response:
try:
return await call_next(request)
except ChromaError as e:
return fastapi_json_response(e)
except ValueError as e:
return ORJSONResponse(
content={"error": "InvalidArgumentError", "message": str(e)},
status_code=400,
)
except TypeError as e:
return ORJSONResponse(
content={"error": "InvalidArgumentError", "message": str(e)},
status_code=400,
)
except Exception as e:
logger.exception(e)
return ORJSONResponse(content={"error": repr(e)}, status_code=500)
class CatchExceptionsMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
try:
return await call_next(request)
except ChromaError as e:
return fastapi_json_response(e)
except ValueError as e:
return ORJSONResponse(
content={"error": "InvalidArgumentError", "message": str(e)},
status_code=400,
)
except TypeError as e:
return ORJSONResponse(
content={"error": "InvalidArgumentError", "message": str(e)},
status_code=400,
)
except Exception as e:
logger.exception(e)
return ORJSONResponse(content={"error": repr(e)}, status_code=500)


async def check_http_version_middleware(
request: Request, call_next: Callable[[Request], Any]
) -> Response:
http_version = request.scope.get("http_version")
if http_version not in ["1.1", "2"]:
raise InvalidHTTPVersion(f"HTTP version {http_version} is not supported")
return await call_next(request)
class CheckHttpVersionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
http_version = request.scope.get("http_version")
if http_version not in ["1.1", "2"]:
raise InvalidHTTPVersion(f"HTTP version {http_version} is not supported")
return await call_next(request)


D = TypeVar("D", bound=BaseModel, contravariant=True)
Expand Down Expand Up @@ -201,9 +200,9 @@ def __init__(self, settings: Settings):
self._quota_enforcer = self._system.require(QuotaEnforcer)
self._system.start()

self._app.middleware("http")(check_http_version_middleware)
self._app.middleware("http")(catch_exceptions_middleware)
self._app.middleware("http")(add_trace_id_to_response_middleware)
self._app.add_middleware(CheckHttpVersionMiddleware)
self._app.add_middleware(CatchExceptionsMiddleware)
self._app.add_middleware(AddTraceIdInResponseHeaderMiddleware)
self._app.add_middleware(
CORSMiddleware,
allow_headers=["*"],
Expand Down
Loading