From 2d2c62bac7b83a8c6766fe3a517f63ff842e5c38 Mon Sep 17 00:00:00 2001 From: pgjones Date: Wed, 27 Dec 2023 17:13:34 +0000 Subject: [PATCH] Improve WSGI compliance The response body is closed if it has a close method as per PEP 3333. In addition the response headers are only sent when the first response body byte is available to send. Finally, an error is raised if start_response has not been called by the app. --- src/hypercorn/app_wrappers.py | 23 ++++++++-- tests/test_app_wrappers.py | 79 ++++++++++++++++++++++------------- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/src/hypercorn/app_wrappers.py b/src/hypercorn/app_wrappers.py index 769e014b..cfc41cfd 100644 --- a/src/hypercorn/app_wrappers.py +++ b/src/hypercorn/app_wrappers.py @@ -84,6 +84,8 @@ async def handle_http( def run_app(self, environ: dict, send: Callable) -> None: headers: List[Tuple[bytes, bytes]] + headers_sent = False + response_started = False status_code: Optional[int] = None def start_response( @@ -91,7 +93,7 @@ def start_response( response_headers: List[Tuple[str, str]], exc_info: Optional[Exception] = None, ) -> None: - nonlocal headers, status_code + nonlocal headers, response_started, status_code raw, _ = status.split(" ", 1) status_code = int(raw) @@ -99,10 +101,23 @@ def start_response( (name.lower().encode("ascii"), value.encode("ascii")) for name, value in response_headers ] - send({"type": "http.response.start", "status": status_code, "headers": headers}) + response_started = True - for output in self.app(environ, start_response): - send({"type": "http.response.body", "body": output, "more_body": True}) + response_body = self.app(environ, start_response) + + if not response_started: + raise RuntimeError("WSGI app did not call start_response") + + try: + for output in response_body: + if not headers_sent: + send({"type": "http.response.start", "status": status_code, "headers": headers}) + headers_sent = True + + send({"type": "http.response.body", "body": output, "more_body": True}) + finally: + if hasattr(response_body, "close"): + response_body.close() def _build_environ(scope: HTTPScope, body: bytes) -> dict: diff --git a/tests/test_app_wrappers.py b/tests/test_app_wrappers.py index bb7b5897..c68ba0cb 100644 --- a/tests/test_app_wrappers.py +++ b/tests/test_app_wrappers.py @@ -61,8 +61,28 @@ async def _send(message: ASGISendEvent) -> None: ] +async def _run_app(app: WSGIWrapper, scope: HTTPScope, body: bytes = b"") -> List[ASGISendEvent]: + queue: asyncio.Queue = asyncio.Queue() + await queue.put({"type": "http.request", "body": body}) + + messages = [] + + async def _send(message: ASGISendEvent) -> None: + nonlocal messages + messages.append(message) + + event_loop = asyncio.get_running_loop() + + def _call_soon(func: Callable, *args: Any) -> Any: + future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) + return future.result() + + await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + return messages + + @pytest.mark.asyncio -async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_wsgi_asyncio() -> None: app = WSGIWrapper(echo_body, 2**16) scope: HTTPScope = { "http_version": "1.1", @@ -79,20 +99,7 @@ async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None: "server": None, "extensions": {}, } - queue: asyncio.Queue = asyncio.Queue() - await queue.put({"type": "http.request"}) - - messages = [] - - async def _send(message: ASGISendEvent) -> None: - nonlocal messages - messages.append(message) - - def _call_soon(func: Callable, *args: Any) -> Any: - future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) - return future.result() - - await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + messages = await _run_app(app, scope) assert messages == [ { "headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")], @@ -105,7 +112,7 @@ def _call_soon(func: Callable, *args: Any) -> Any: @pytest.mark.asyncio -async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None: +async def test_max_body_size() -> None: app = WSGIWrapper(echo_body, 4) scope: HTTPScope = { "http_version": "1.1", @@ -122,25 +129,39 @@ async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None: "server": None, "extensions": {}, } - queue: asyncio.Queue = asyncio.Queue() - await queue.put({"type": "http.request", "body": b"abcde"}) - messages = [] - - async def _send(message: ASGISendEvent) -> None: - nonlocal messages - messages.append(message) - - def _call_soon(func: Callable, *args: Any) -> Any: - future = asyncio.run_coroutine_threadsafe(func(*args), event_loop) - return future.result() - - await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon) + messages = await _run_app(app, scope, b"abcde") assert messages == [ {"headers": [], "status": 400, "type": "http.response.start"}, {"body": bytearray(b""), "type": "http.response.body", "more_body": False}, ] +def no_start_response(environ: dict, start_response: Callable) -> List[bytes]: + return [b"result"] + + +@pytest.mark.asyncio +async def test_no_start_response() -> None: + app = WSGIWrapper(no_start_response, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": "http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + } + with pytest.raises(RuntimeError): + await _run_app(app, scope) + + def test_build_environ_encoding() -> None: scope: HTTPScope = { "http_version": "1.0",