Skip to content

Commit

Permalink
More fine-tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Jun 14, 2022
1 parent 4c48d3d commit b70f13d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 13 deletions.
16 changes: 15 additions & 1 deletion starlette/middleware/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,22 @@
from ..types import ASGIApp, Message, Receive, Scope, Send

_DispatchFlow = Union[
# Default case:
# response = yield
AsyncGenerator[None, Response],
AsyncGenerator[Response, Response],
# A/ Early response:
# if condition:
# yield Response(...)
# return
# response = yield None
#
# or B/ Error handling:
# try:
# response = yield None
# except ...:
# yield Response(...)
# else:
# ...
AsyncGenerator[Optional[Response], Response],
]

Expand Down
37 changes: 25 additions & 12 deletions tests/middleware/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,28 +154,41 @@ def test_middleware_repr():


def test_early_response(test_client_factory):
async def index(request):
return PlainTextResponse("Hello, world!")

class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[Response, Response]:
yield Response(status_code=401)
) -> AsyncGenerator[Optional[Response], Response]:
if conn.headers.get("X-Early") == "true":
yield Response(status_code=401)
return

app = Starlette(middleware=[Middleware(CustomMiddleware)])
yield None

app = Starlette(
routes=[Route("/", index)],
middleware=[Middleware(CustomMiddleware)],
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world!"
response = client.get("/", headers={"X-Early": "true"})
assert response.status_code == 401


def test_too_many_yields(test_client_factory) -> None:
class BadMiddleware(HTTPMiddleware):
class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
_ = yield
yield

app = Starlette(middleware=[Middleware(BadMiddleware)])
app = Starlette(middleware=[Middleware(CustomMiddleware)])

client = test_client_factory(app)
with pytest.raises(RuntimeError, match="should yield exactly once"):
Expand All @@ -189,7 +202,7 @@ class Failed(Exception):
async def failure(request):
raise Failed()

class ErrorMiddleware(HTTPMiddleware):
class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[Optional[Response], Response]:
Expand All @@ -200,7 +213,7 @@ async def dispatch(

app = Starlette(
routes=[Route("/fail", failure)],
middleware=[Middleware(ErrorMiddleware)],
middleware=[Middleware(CustomMiddleware)],
)

client = test_client_factory(app)
Expand All @@ -216,7 +229,7 @@ class Failed(Exception):
async def index(request):
raise Failed()

class BadMiddleware(HTTPMiddleware):
class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
Expand All @@ -225,7 +238,10 @@ async def dispatch(
except Failed:
pass

app = Starlette(routes=[Route("/", index)], middleware=[Middleware(BadMiddleware)])
app = Starlette(
routes=[Route("/", index)],
middleware=[Middleware(CustomMiddleware)],
)

client = test_client_factory(app)
with pytest.raises(RuntimeError, match="no response was returned"):
Expand Down Expand Up @@ -260,9 +276,6 @@ async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
async def homepage(request):
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")
Expand Down

0 comments on commit b70f13d

Please sign in to comment.