Skip to content

Commit

Permalink
Improve WSGI compliance
Browse files Browse the repository at this point in the history
The response body is closed if it has a close method as per PEP
3333. In addition the response headers are only sent when the first
response body byte is available to send. Finally, an error is raised
if start_response has not been called by the app.
  • Loading branch information
pgjones committed Dec 27, 2023
1 parent cb443a4 commit 2d2c62b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 33 deletions.
23 changes: 19 additions & 4 deletions src/hypercorn/app_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,40 @@ async def handle_http(

def run_app(self, environ: dict, send: Callable) -> None:
headers: List[Tuple[bytes, bytes]]
headers_sent = False
response_started = False
status_code: Optional[int] = None

def start_response(
status: str,
response_headers: List[Tuple[str, str]],
exc_info: Optional[Exception] = None,
) -> None:
nonlocal headers, status_code
nonlocal headers, response_started, status_code

raw, _ = status.split(" ", 1)
status_code = int(raw)
headers = [
(name.lower().encode("ascii"), value.encode("ascii"))
for name, value in response_headers
]
send({"type": "http.response.start", "status": status_code, "headers": headers})
response_started = True

for output in self.app(environ, start_response):
send({"type": "http.response.body", "body": output, "more_body": True})
response_body = self.app(environ, start_response)

if not response_started:
raise RuntimeError("WSGI app did not call start_response")

try:
for output in response_body:
if not headers_sent:
send({"type": "http.response.start", "status": status_code, "headers": headers})
headers_sent = True

send({"type": "http.response.body", "body": output, "more_body": True})
finally:
if hasattr(response_body, "close"):
response_body.close()


def _build_environ(scope: HTTPScope, body: bytes) -> dict:
Expand Down
79 changes: 50 additions & 29 deletions tests/test_app_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,28 @@ async def _send(message: ASGISendEvent) -> None:
]


async def _run_app(app: WSGIWrapper, scope: HTTPScope, body: bytes = b"") -> List[ASGISendEvent]:
queue: asyncio.Queue = asyncio.Queue()
await queue.put({"type": "http.request", "body": body})

messages = []

async def _send(message: ASGISendEvent) -> None:
nonlocal messages
messages.append(message)

event_loop = asyncio.get_running_loop()

def _call_soon(func: Callable, *args: Any) -> Any:
future = asyncio.run_coroutine_threadsafe(func(*args), event_loop)
return future.result()

await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon)
return messages


@pytest.mark.asyncio
async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None:
async def test_wsgi_asyncio() -> None:
app = WSGIWrapper(echo_body, 2**16)
scope: HTTPScope = {
"http_version": "1.1",
Expand All @@ -79,20 +99,7 @@ async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None:
"server": None,
"extensions": {},
}
queue: asyncio.Queue = asyncio.Queue()
await queue.put({"type": "http.request"})

messages = []

async def _send(message: ASGISendEvent) -> None:
nonlocal messages
messages.append(message)

def _call_soon(func: Callable, *args: Any) -> Any:
future = asyncio.run_coroutine_threadsafe(func(*args), event_loop)
return future.result()

await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon)
messages = await _run_app(app, scope)
assert messages == [
{
"headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")],
Expand All @@ -105,7 +112,7 @@ def _call_soon(func: Callable, *args: Any) -> Any:


@pytest.mark.asyncio
async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None:
async def test_max_body_size() -> None:
app = WSGIWrapper(echo_body, 4)
scope: HTTPScope = {
"http_version": "1.1",
Expand All @@ -122,25 +129,39 @@ async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None:
"server": None,
"extensions": {},
}
queue: asyncio.Queue = asyncio.Queue()
await queue.put({"type": "http.request", "body": b"abcde"})
messages = []

async def _send(message: ASGISendEvent) -> None:
nonlocal messages
messages.append(message)

def _call_soon(func: Callable, *args: Any) -> Any:
future = asyncio.run_coroutine_threadsafe(func(*args), event_loop)
return future.result()

await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon)
messages = await _run_app(app, scope, b"abcde")
assert messages == [
{"headers": [], "status": 400, "type": "http.response.start"},
{"body": bytearray(b""), "type": "http.response.body", "more_body": False},
]


def no_start_response(environ: dict, start_response: Callable) -> List[bytes]:
return [b"result"]


@pytest.mark.asyncio
async def test_no_start_response() -> None:
app = WSGIWrapper(no_start_response, 2**16)
scope: HTTPScope = {
"http_version": "1.1",
"asgi": {},
"method": "GET",
"headers": [],
"path": "/",
"root_path": "/",
"query_string": b"a=b",
"raw_path": b"/",
"scheme": "http",
"type": "http",
"client": ("localhost", 80),
"server": None,
"extensions": {},
}
with pytest.raises(RuntimeError):
await _run_app(app, scope)


def test_build_environ_encoding() -> None:
scope: HTTPScope = {
"http_version": "1.0",
Expand Down

0 comments on commit 2d2c62b

Please sign in to comment.