From a23ae268ad49cb44f6c3df16d489ce2d0ae65070 Mon Sep 17 00:00:00 2001 From: Ben Falk Date: Tue, 8 Dec 2020 11:38:09 -0500 Subject: [PATCH 1/3] add timeout support for testclient raise ValueError if user submits tuple simplify by removing inline function typing updates --- starlette/testclient.py | 17 +++++++++++++++-- tests/test_testclient.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index c17c51819..7232befed 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -10,6 +10,7 @@ from urllib.parse import unquote, urljoin, urlsplit import requests +from urllib3.util.timeout import Timeout from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -96,7 +97,11 @@ def __init__( self.root_path = root_path def send( - self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any + self, + request: requests.PreparedRequest, + *args: typing.Any, + timeout: Timeout = None, + **kwargs: typing.Any, ) -> requests.Response: scheme, netloc, path, query, fragment = ( str(item) for item in urlsplit(request.url) @@ -237,7 +242,15 @@ async def send(message: Message) -> None: asyncio.set_event_loop(loop) try: - loop.run_until_complete(self.app(scope, receive, send)) + if isinstance(timeout, tuple): + err = ( + "Invalid timeout {}. testclient only supports float (not tuple)" + "at this time ".format(timeout) + ) + raise ValueError(err) + loop.run_until_complete( + asyncio.wait_for(self.app(scope, receive, send), timeout) + ) except BaseException as exc: if self.raise_server_exceptions: raise exc from None diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 00f4e0125..bb0ed05b5 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,4 +1,5 @@ import asyncio +import time import pytest @@ -16,6 +17,19 @@ def mock_service_endpoint(request): return JSONResponse({"mock": "example"}) +@mock_service.route("/slow_response") +def slow_response(request): + time.sleep(0.01) + return JSONResponse({"mock": "slow example"}) + + +@mock_service.route("/async_slow_response") +async def async_slow_response(request): + # time.sleep(0.01) + await asyncio.sleep(0.01) + return JSONResponse({"mock": "slow example"}) + + app = Starlette() @@ -132,3 +146,20 @@ async def asgi(receive, send): with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} + + +@pytest.mark.parametrize("endpoint", ["/slow_response", "/async_slow_response"]) +def test_timeout(endpoint): + client = TestClient(mock_service, raise_server_exceptions=True) + + with pytest.raises(ValueError): + client.get("/slow_response", timeout=(1, 1)) + + with pytest.raises(asyncio.TimeoutError): + client.get(endpoint, timeout=0.001) + + response = client.get(endpoint, timeout=1) + assert response.json() == {"mock": "slow example"} + + response = client.get(endpoint) + assert response.json() == {"mock": "slow example"} From 3ea2f7d57e05a619e4198edef85c593e04183f0b Mon Sep 17 00:00:00 2001 From: Ben Falk Date: Tue, 8 Dec 2020 14:08:30 -0500 Subject: [PATCH 2/3] use parametrize value in test --- tests/test_testclient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_testclient.py b/tests/test_testclient.py index bb0ed05b5..143594e4e 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -153,7 +153,7 @@ def test_timeout(endpoint): client = TestClient(mock_service, raise_server_exceptions=True) with pytest.raises(ValueError): - client.get("/slow_response", timeout=(1, 1)) + client.get(endpoint, timeout=(1, 1)) with pytest.raises(asyncio.TimeoutError): client.get(endpoint, timeout=0.001) From 2403e82e4090e79309853fe5841152da65a494c2 Mon Sep 17 00:00:00 2001 From: Ben Falk Date: Tue, 8 Dec 2020 15:21:14 -0500 Subject: [PATCH 3/3] add test raise server exception=False --- tests/test_testclient.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 143594e4e..d702ad195 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -163,3 +163,17 @@ def test_timeout(endpoint): response = client.get(endpoint) assert response.json() == {"mock": "slow example"} + + client = TestClient(mock_service, raise_server_exceptions=False) + + response = client.get(endpoint, timeout=(1, 1)) + assert response.status_code == 500 + + response = client.get(endpoint, timeout=0.001) + assert response.status_code == 500 + + response = client.get(endpoint, timeout=1) + assert response.json() == {"mock": "slow example"} + + response = client.get(endpoint) + assert response.json() == {"mock": "slow example"}