diff --git a/tests/test_server.py b/tests/test_server.py index 5b9d9ba2d..c650be290 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,15 +2,23 @@ import asyncio import contextlib +import logging import signal import sys from typing import Callable, ContextManager, Generator +import httpx import pytest +from tests.utils import run_server +from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.config import Config +from uvicorn.protocols.http.h11_impl import H11Protocol +from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.server import Server +pytestmark = pytest.mark.anyio + # asyncio does NOT allow raising in signal handlers, so to detect # raised signals raised a mutable `witness` receives the signal @@ -37,6 +45,12 @@ async def dummy_app(scope, receive, send): # pragma: py-win32 pass +async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: + assert scope["type"] == "http" + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + if sys.platform == "win32": # pragma: py-not-win32 signals = [signal.SIGBREAK] signal_captures = [capture_signal_sync] @@ -45,7 +59,6 @@ async def dummy_app(scope, receive, send): # pragma: py-win32 signal_captures = [capture_signal_sync, capture_signal_async] -@pytest.mark.anyio @pytest.mark.parametrize("exception_signal", signals) @pytest.mark.parametrize("capture_signal", signal_captures) async def test_server_interrupt( @@ -65,3 +78,16 @@ async def interrupt_running(srv: Server): assert witness # set by the server's graceful exit handler assert server.should_exit + + +async def test_request_than_limit_max_requests_warn_log( + unused_tcp_port: int, http_protocol_cls: type[H11Protocol | HttpToolsProtocol], caplog: pytest.LogCaptureFixture +): + caplog.set_level(logging.WARNING, logger="uvicorn.error") + config = Config(app=app, limit_max_requests=1, port=unused_tcp_port, http=http_protocol_cls) + async with run_server(config): + async with httpx.AsyncClient() as client: + tasks = [client.get(f"http://127.0.0.1:{unused_tcp_port}") for _ in range(2)] + responses = await asyncio.gather(*tasks) + assert len(responses) == 2 + assert "Maximum request limit of 1 exceeded. Terminating process." in caplog.text diff --git a/tools/cli_usage.py b/tools/cli_usage.py index 5a9115710..1cfa88672 100644 --- a/tools/cli_usage.py +++ b/tools/cli_usage.py @@ -2,24 +2,26 @@ Look for a marker comment in docs pages, and place the output of `$ uvicorn --help` there. Pass `--check` to ensure the content is in sync. """ + +from __future__ import annotations + import argparse import subprocess import sys -import typing from pathlib import Path -def _get_usage_lines() -> typing.List[str]: +def _get_usage_lines() -> list[str]: res = subprocess.run(["uvicorn", "--help"], stdout=subprocess.PIPE) help_text = res.stdout.decode("utf-8") return ["```", "$ uvicorn --help", *help_text.splitlines(), "```"] -def _find_next_codefence_lineno(lines: typing.List[str], after: int) -> int: +def _find_next_codefence_lineno(lines: list[str], after: int) -> int: return next(lineno for lineno, line in enumerate(lines[after:], after) if line == "```") -def _get_insert_location(lines: typing.List[str]) -> typing.Tuple[int, int]: +def _get_insert_location(lines: list[str]) -> tuple[int, int]: marker = lines.index("") start = marker + 1 diff --git a/uvicorn/server.py b/uvicorn/server.py index fa7638b7d..f14026f16 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -250,8 +250,12 @@ async def on_tick(self, counter: int) -> bool: # Determine if we should exit. if self.should_exit: return True - if self.config.limit_max_requests is not None: - return self.server_state.total_requests >= self.config.limit_max_requests + + max_requests = self.config.limit_max_requests + if max_requests is not None and self.server_state.total_requests >= max_requests: + logger.warning(f"Maximum request limit of {max_requests} exceeded. Terminating process.") + return True + return False async def shutdown(self, sockets: list[socket.socket] | None = None) -> None: