Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment a high-level HTTPMiddleware #1691

Closed
wants to merge 13 commits into from
94 changes: 94 additions & 0 deletions starlette/middleware/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from contextlib import aclosing
from functools import partial
from typing import AsyncGenerator, Callable, Optional

from ..datastructures import MutableHeaders
from ..responses import Response
from ..types import ASGIApp, Message, Receive, Scope, Send

HTTPDispatchFlow = AsyncGenerator[Optional[Response], Response]


class HTTPMiddleware:
def __init__(
self,
app: ASGIApp,
dispatch_func: Optional[Callable[[Scope], HTTPDispatchFlow]] = None,
) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func

def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
raise NotImplementedError # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

async with aclosing(self.dispatch(scope)) as flow:
# Kick the flow until the first `yield`.
# Might respond early before we call into the app.
early_response = await flow.__anext__()

if early_response is not None:
await early_response(scope, receive, send)
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
return

response_started = set[bool]()

wrapped_send = partial(
self._send,
flow=flow,
response_started=response_started,
send=send,
)
adriangb marked this conversation as resolved.
Show resolved Hide resolved

try:
await self.app(scope, receive, wrapped_send)
except Exception as exc:
if response_started:
raise

try:
response = await flow.athrow(exc)
except Exception:
# Exception was not handled, or they raised another one.
raise

if response is None:
raise RuntimeError(
f"dispatch() handled exception {exc!r}, "
"but no response was returned"
)

await response(scope, receive, send)

if not response_started:
raise RuntimeError("No response returned.")

async def _send(
self,
message: Message,
*,
flow: HTTPDispatchFlow,
response_started: set,
send: Send,
) -> None:
if message["type"] == "http.response.start":
response_started.add(True)

response = Response(status_code=message["status"])
response.raw_headers.clear()

try:
await flow.asend(response)
except StopAsyncIteration:
pass
else:
raise RuntimeError("dispatch() should yield exactly once")

headers = MutableHeaders(raw=message["headers"])
headers.update(response.headers)

await send(message)
208 changes: 208 additions & 0 deletions tests/middleware/test_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import contextvars

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.http import HTTPDispatchFlow, HTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send


class CustomMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
response = yield None
response.headers["Custom-Header"] = "Example"


def homepage(request):
return PlainTextResponse("Homepage")


def exc(request):
raise Exception("Exc")


def exc_stream(request):
return StreamingResponse(_generate_faulty_stream())


def _generate_faulty_stream():
yield b"Ok"
raise Exception("Faulty Stream")


class NoResponse:
def __init__(self, scope, receive, send):
pass

def __await__(self):
return self.dispatch().__await__()

async def dispatch(self):
pass


async def websocket_endpoint(session):
await session.accept()
await session.send_text("Hello, world!")
await session.close()


app = Starlette(
routes=[
Route("/", endpoint=homepage),
Route("/exc", endpoint=exc),
Route("/exc-stream", endpoint=exc_stream),
Route("/no-response", endpoint=NoResponse),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
],
middleware=[Middleware(CustomMiddleware)],
)


def test_custom_middleware(test_client_factory):
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

with pytest.raises(Exception) as ctx:
response = client.get("/exc")
assert str(ctx.value) == "Exc"

with pytest.raises(Exception) as ctx:
response = client.get("/exc-stream")
assert str(ctx.value) == "Faulty Stream"

with pytest.raises(RuntimeError):
response = client.get("/no-response")

with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"


def test_state_data_across_multiple_middlewares(test_client_factory):
expected_value1 = "foo"
expected_value2 = "bar"

class aMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
scope["state_foo"] = expected_value1
yield None

class bMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
scope["state_bar"] = expected_value2
response = yield None
response.headers["X-State-Foo"] = scope["state_foo"]

class cMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
response = yield None
response.headers["X-State-Bar"] = scope["state_bar"]

def homepage(request):
return PlainTextResponse("OK")

app = Starlette(
routes=[Route("/", homepage)],
middleware=[
Middleware(aMiddleware),
Middleware(bMiddleware),
Middleware(cMiddleware),
],
)

client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"
assert response.headers["X-State-Foo"] == expected_value1
assert response.headers["X-State-Bar"] == expected_value2


def test_app_middleware_argument(test_client_factory):
def homepage(request):
return PlainTextResponse("Homepage")

app = Starlette(
routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
)

client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"


def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"


def test_fully_evaluated_response(test_client_factory):
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
yield PlainTextResponse("Custom")

app = Starlette(middleware=[Middleware(CustomMiddleware)])

client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_exception_on_mounted_apps(test_client_factory):
sub_app = Starlette(routes=[Route("/", exc)])
app = Starlette(routes=[Mount("/sub", app=sub_app)])

client = test_client_factory(app)
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved


ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")


class CustomMiddlewareWithoutBaseHTTPMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
ctxvar.set("set by middleware")
await self.app(scope, receive, send)
assert ctxvar.get() == "set by endpoint"


class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
ctxvar.set("set by middleware")
yield None
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
assert ctxvar.get() == "set by endpoint"


@pytest.mark.parametrize(
"middleware_cls",
[
CustomMiddlewareWithoutBaseHTTPMiddleware,
CustomMiddlewareUsingHTTPMiddleware,
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
async def homepage(request):
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")

app = Starlette(
middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content