diff --git a/src/litserve/server.py b/src/litserve/server.py index ff913e02..3d69ef92 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -113,7 +113,7 @@ def __init__( stream: bool = False, spec: Optional[LitSpec] = None, max_payload_size=None, - middlewares: Optional[list[tuple[Callable, dict]]] = None, + middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None, ): if batch_timeout > timeout and timeout not in (False, -1): raise ValueError("batch_timeout must be less than timeout") @@ -121,8 +121,17 @@ def __init__( raise ValueError("max_batch_size must be greater than 0") if isinstance(spec, OpenAISpec): stream = True + if middlewares is None: middlewares = [] + if not isinstance(middlewares, list): + _msg = ( + "middlewares must be a list of tuples" + " where each tuple contains a middleware and its arguments. For example:\n" + "server = ls.LitServer(ls.examples.SimpleLitAPI(), " + 'middlewares=[(RequestIdMiddleware, {"length": 5})])' + ) + raise ValueError(_msg) if not api_path.startswith("/"): raise ValueError( @@ -364,8 +373,12 @@ async def stream_predict(request: self.request_type) -> self.response_type: path, endpoint=endpoint, methods=methods, dependencies=[Depends(self.setup_auth())] ) - for middleware, kwargs in self.middlewares: - self.app.add_middleware(middleware, **kwargs) + for middleware in self.middlewares: + if isinstance(middleware, tuple): + middleware, kwargs = middleware + self.app.add_middleware(middleware, **kwargs) + elif callable(middleware): + self.app.add_middleware(middleware) @staticmethod def generate_client_file(): diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index e14bbbc3..c61d2ed0 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -23,6 +23,8 @@ import torch.nn as nn from httpx import AsyncClient from litserve.utils import wrap_litserve_start +from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware +from starlette.middleware.trustedhost import TrustedHostMiddleware from unittest.mock import patch, MagicMock import pytest @@ -393,3 +395,37 @@ def test_custom_middleware(): assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}" assert response.json() == {"output": 16.0}, "server didn't return expected output" assert response.headers["X-Request-Id"] == "00000" + + +def test_starlette_middlewares(): + middlewares = [ + ( + TrustedHostMiddleware, + { + "allowed_hosts": ["localhost", "127.0.0.1"], + }, + ), + HTTPSRedirectMiddleware, + ] + server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=middlewares) + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + response = client.post("/predict", json={"input": 4.0}, headers={"Host": "localhost"}) + assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}" + assert response.json() == {"output": 16.0}, "server didn't return expected output" + + response = client.post("/predict", json={"input": 4.0}, headers={"Host": "not-trusted-host"}) + assert response.status_code == 400, f"Expected response to be 400 but got {response.status_code}" + + +def test_middlewares_inputs(): + server = ls.LitServer(SimpleLitAPI(), middlewares=[]) + assert len(server.middlewares) == 1, "Default middleware should be present" + + server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=[], max_payload_size=1000) + assert len(server.middlewares) == 2, "Default middleware should be present" + + server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=None) + assert len(server.middlewares) == 1, "Default middleware should be present" + + with pytest.raises(ValueError, match="middlewares must be a list of tuples"): + ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))