Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Jun 14, 2022
1 parent d192a6b commit a0aa64a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 62 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_long_description():
install_requires=[
"anyio>=3.4.0,<5",
"typing_extensions>=3.10.0; python_version < '3.10'",
"async_generator; python < '3.10'",
"async_generator; python_version < '3.10'",
],
extras_require={
"full": [
Expand Down
63 changes: 26 additions & 37 deletions starlette/middleware/http.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from functools import partial
from typing import AsyncGenerator, Callable, Optional
from typing import AsyncGenerator, Callable, Optional, Union

from .._compat import aclosing
from ..datastructures import MutableHeaders
from ..responses import Response
from ..types import ASGIApp, Message, Receive, Scope, Send

HTTPDispatchFlow = AsyncGenerator[Optional[Response], Response]
_HTTPDispatchFlow = Union[
AsyncGenerator[None, Response], AsyncGenerator[Response, Response]
]


class HTTPMiddleware:
def __init__(
self,
app: ASGIApp,
dispatch_func: Optional[Callable[[Scope], HTTPDispatchFlow]] = None,
dispatch_func: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func

def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
def dispatch(
self, scope: Scope
) -> Union[AsyncGenerator[None, Response], AsyncGenerator[Response, Response]]:
raise NotImplementedError # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand All @@ -37,12 +40,24 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

response_started = set[bool]()

wrapped_send = partial(
self._send,
flow=flow,
response_started=response_started,
send=send,
)
async def wrapped_send(message: Message) -> None:
if message["type"] == "http.response.start":
response_started.add(True)

response = Response(status_code=message["status"])
response.raw_headers.clear()

try:
await flow.asend(response)
except StopAsyncIteration:
pass
else:
raise RuntimeError("dispatch() should yield exactly once")

headers = MutableHeaders(raw=message["headers"])
headers.update(response.headers)

await send(message)

try:
await self.app(scope, receive, wrapped_send)
Expand All @@ -66,29 +81,3 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

if not response_started:
raise RuntimeError("No response returned.")

async def _send(
self,
message: Message,
*,
flow: HTTPDispatchFlow,
response_started: set,
send: Send,
) -> None:
if message["type"] == "http.response.start":
response_started.add(True)

response = Response(status_code=message["status"])
response.raw_headers.clear()

try:
await flow.asend(response)
except StopAsyncIteration:
pass
else:
raise RuntimeError("dispatch() should yield exactly once")

headers = MutableHeaders(raw=message["headers"])
headers.update(response.headers)

await send(message)
39 changes: 15 additions & 24 deletions tests/middleware/test_http.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import contextvars
from typing import AsyncGenerator

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.http import HTTPDispatchFlow, HTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.middleware.http import HTTPMiddleware
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send


class CustomMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
response = yield None
async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
response = yield
response.headers["Custom-Header"] = "Example"


Expand Down Expand Up @@ -88,19 +89,19 @@ def test_state_data_across_multiple_middlewares(test_client_factory):
expected_value2 = "bar"

class aMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
scope["state_foo"] = expected_value1
yield None
yield

class bMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
scope["state_bar"] = expected_value2
response = yield None
response = yield
response.headers["X-State-Foo"] = scope["state_foo"]

class cMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
response = yield None
async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
response = yield
response.headers["X-State-Bar"] = scope["state_bar"]

def homepage(request):
Expand Down Expand Up @@ -143,7 +144,7 @@ def test_middleware_repr():
def test_fully_evaluated_response(test_client_factory):
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]:
yield PlainTextResponse("Custom")

app = Starlette(middleware=[Middleware(CustomMiddleware)])
Expand All @@ -153,16 +154,6 @@ async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
assert response.text == "Custom"


def test_exception_on_mounted_apps(test_client_factory):
sub_app = Starlette(routes=[Route("/", exc)])
app = Starlette(routes=[Mount("/sub", app=sub_app)])

client = test_client_factory(app)
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"


ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")


Expand All @@ -177,9 +168,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:


class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
ctxvar.set("set by middleware")
yield None
yield
assert ctxvar.get() == "set by endpoint"


Expand Down

0 comments on commit a0aa64a

Please sign in to comment.