diff --git a/src/hypercorn/app_wrappers.py b/src/hypercorn/app_wrappers.py index 769e014b..d620fc93 100644 --- a/src/hypercorn/app_wrappers.py +++ b/src/hypercorn/app_wrappers.py @@ -78,12 +78,15 @@ async def handle_http( environ = _build_environ(scope, body) except InvalidPathError: await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) else: await sync_spawn(self.run_app, environ, partial(call_soon, send)) - await send({"type": "http.response.body", "body": b"", "more_body": False}) def run_app(self, environ: dict, send: Callable) -> None: headers: List[Tuple[bytes, bytes]] + headers_set: bool = False + headers_sent: bool = False + implicit_content_length: bool = False status_code: Optional[int] = None def start_response( @@ -91,7 +94,7 @@ def start_response( response_headers: List[Tuple[str, str]], exc_info: Optional[Exception] = None, ) -> None: - nonlocal headers, status_code + nonlocal headers, headers_set, status_code raw, _ = status.split(" ", 1) status_code = int(raw) @@ -99,10 +102,39 @@ def start_response( (name.lower().encode("ascii"), value.encode("ascii")) for name, value in response_headers ] + headers_set = True + + def send_headers(content_length: int): + nonlocal headers, headers_sent + if not headers_set: + raise AssertionError("missing call to start_response") + if implicit_content_length: # We can determine the content-length ourself if not set + for name, _ in headers: + if name == b"content-length": + break + else: # No content-length set by the application + headers.append((b"content-length", str(content_length).encode("ascii"))) send({"type": "http.response.start", "status": status_code, "headers": headers}) + headers_sent = True - for output in self.app(environ, start_response): - send({"type": "http.response.body", "body": output, "more_body": True}) + def send_body(data: bytes): + if not headers_sent: + send_headers(len(data)) + send({"type": "http.response.body", "body": data, "more_body": True if data else False}) + + response_body_iter = self.app(environ, start_response) + + if hasattr(response_body_iter, "__len__") and len(response_body_iter) == 1: + implicit_content_length = True + + try: + for output in response_body_iter: + if output: + send_body(output) + send_body(b"") + finally: + if hasattr(response_body_iter, "close"): + response_body_iter.close() def _build_environ(scope: HTTPScope, body: bytes) -> dict: