Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Starlette middleware support #253

Merged
merged 3 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,25 @@ def __init__(
stream: bool = False,
spec: Optional[LitSpec] = None,
max_payload_size=None,
middlewares: Optional[list[tuple[Callable, dict]]] = None,
middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
):
if batch_timeout > timeout and timeout not in (False, -1):
raise ValueError("batch_timeout must be less than timeout")
if max_batch_size <= 0:
raise ValueError("max_batch_size must be greater than 0")
if isinstance(spec, OpenAISpec):
stream = True

if middlewares is None:
middlewares = []
if not isinstance(middlewares, list):
_msg = (
"middlewares must be a list of tuples"
" where each tuple contains a middleware and its arguments. For example:\n"
"server = ls.LitServer(ls.examples.SimpleLitAPI(), "
'middlewares=[(RequestIdMiddleware, {"length": 5})])'
)
raise ValueError(_msg)

if not api_path.startswith("/"):
raise ValueError(
Expand Down Expand Up @@ -364,8 +373,12 @@ async def stream_predict(request: self.request_type) -> self.response_type:
path, endpoint=endpoint, methods=methods, dependencies=[Depends(self.setup_auth())]
)

for middleware, kwargs in self.middlewares:
self.app.add_middleware(middleware, **kwargs)
for middleware in self.middlewares:
if isinstance(middleware, tuple):
middleware, kwargs = middleware
self.app.add_middleware(middleware, **kwargs)
elif callable(middleware):
self.app.add_middleware(middleware)

@staticmethod
def generate_client_file():
Expand Down
36 changes: 36 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import torch.nn as nn
from httpx import AsyncClient
from litserve.utils import wrap_litserve_start
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

from unittest.mock import patch, MagicMock
import pytest
Expand Down Expand Up @@ -393,3 +395,37 @@ def test_custom_middleware():
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
assert response.json() == {"output": 16.0}, "server didn't return expected output"
assert response.headers["X-Request-Id"] == "00000"


def test_starlette_middlewares():
middlewares = [
(
TrustedHostMiddleware,
{
"allowed_hosts": ["localhost", "127.0.0.1"],
},
),
HTTPSRedirectMiddleware,
]
server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=middlewares)
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0}, headers={"Host": "localhost"})
assert response.status_code == 200, f"Expected response to be 200 but got {response.status_code}"
assert response.json() == {"output": 16.0}, "server didn't return expected output"

response = client.post("/predict", json={"input": 4.0}, headers={"Host": "not-trusted-host"})
assert response.status_code == 400, f"Expected response to be 400 but got {response.status_code}"


def test_middlewares_inputs():
server = ls.LitServer(SimpleLitAPI(), middlewares=[])
assert len(server.middlewares) == 1, "Default middleware should be present"

server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=[], max_payload_size=1000)
assert len(server.middlewares) == 2, "Default middleware should be present"

server = ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=None)
assert len(server.middlewares) == 1, "Default middleware should be present"

with pytest.raises(ValueError, match="middlewares must be a list of tuples"):
ls.LitServer(ls.examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))
Loading