diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py index aa1675263..76298e4d1 100644 --- a/starlette/middleware/http.py +++ b/starlette/middleware/http.py @@ -66,9 +66,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return response_started = False + response_finished = False async def wrapped_send(message: Message) -> None: - nonlocal response_started + nonlocal response_started, response_finished if message["type"] == "http.response.start": response_started = True @@ -77,17 +78,27 @@ async def wrapped_send(message: Message) -> None: response.raw_headers.clear() try: - await flow.asend(response) + new_response = await flow.asend(response) except StopAsyncIteration: pass else: - raise RuntimeError("dispatch() should yield exactly once") - + if new_response is None: + raise RuntimeError("dispatch() should yield exactly once") + try: + await flow.__anext__() + except StopAsyncIteration: + pass + else: + raise RuntimeError("dispatch() should yield exactly once") + await new_response(scope, receive, send) + response_finished = True + return headers = MutableHeaders(raw=message["headers"]) headers.update(response.headers) message["headers"] = headers.raw - await send(message) + if not response_finished: + await send(message) try: await self.app(scope, receive, wrapped_send) diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py index cbe1c3ac7..1add61fcf 100644 --- a/tests/middleware/test_http.py +++ b/tests/middleware/test_http.py @@ -169,6 +169,79 @@ async def dispatch( client.get("/") +def test_replace_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: + async def index(request: Request) -> Response: + return PlainTextResponse("Hello, world!") + + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[Optional[Response], Response]: + yield None + yield PlainTextResponse("Custom") + + app = Starlette( + routes=[Route("/", index)], + middleware=[Middleware(CustomMiddleware)], + ) + + client = test_client_factory(app) + + resp = client.get("/") + assert resp.text == "Custom" + + +def test_replace_response_too_many_yields( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + async def index(request: Request) -> Response: + return PlainTextResponse("Hello, world!") + + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[Optional[Response], Response]: + yield None + yield PlainTextResponse("Custom") + yield None + + app = Starlette( + routes=[Route("/", index)], + middleware=[Middleware(CustomMiddleware)], + ) + + client = test_client_factory(app) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="should yield exactly once"): + client.get("/") + + +def test_replace_response_yield_None( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + async def index(request: Request) -> Response: + return PlainTextResponse("Hello, world!") + + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[Optional[Response], Response]: + yield None + yield None + + app = Starlette( + routes=[Route("/", index)], + middleware=[Middleware(CustomMiddleware)], + ) + + client = test_client_factory(app) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="should yield exactly once"): + client.get("/") + + def test_error_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: class Failed(Exception): pass