Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mock requests.Session.get in TestClient #150

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 35 additions & 33 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import arxiv
from datetime import datetime, timedelta
from pytest import approx
from requests import Response

def empty_response(code: int) -> Response:
r = Response()
r.status_code = code
r._content = b''
return r

class TestClient(unittest.TestCase):
def test_invalid_format_id(self):
Expand Down Expand Up @@ -90,10 +96,10 @@ def test_no_duplicates(self):
self.assertFalse(r.entry_id in ids)
ids.add(r.entry_id)

@patch('requests.Session.get', return_value=empty_response(500))
@patch("time.sleep", return_value=None)
def test_retry(self, patched_time_sleep):
broken_client = TestClient.get_code_client(500)

def test_retry(self, mock_sleep, mock_get):
broken_client = arxiv.Client()
def broken_get():
search = arxiv.Search(query="quantum")
return next(broken_client.results(search))
Expand All @@ -109,77 +115,73 @@ def broken_get():
self.assertEqual(e.status, 500)
self.assertEqual(e.retry, broken_client.num_retries)

@patch('requests.Session.get', return_value=empty_response(200))
@patch("time.sleep", return_value=None)
def test_sleep_standard(self, patched_time_sleep):
client = TestClient.get_code_client(200)
def test_sleep_standard(self, mock_sleep, mock_get):
client = arxiv.Client()
url = client._format_url(arxiv.Search(query="quantum"), 0, 1)
# A client should sleep until delay_seconds have passed.
client._parse_feed(url)
patched_time_sleep.assert_not_called()
mock_sleep.assert_not_called()
# Overwrite _last_request_dt to minimize flakiness: different
# environments will have different page fetch times.
client._last_request_dt = datetime.now()
client._parse_feed(url)
patched_time_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3))
mock_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3))

@patch('requests.Session.get', return_value=empty_response(200))
@patch("time.sleep", return_value=None)
def test_sleep_multiple_requests(self, patched_time_sleep):
client = TestClient.get_code_client(200)
def test_sleep_multiple_requests(self, mock_sleep, mock_get):
client = arxiv.Client()
url1 = client._format_url(arxiv.Search(query="quantum"), 0, 1)
url2 = client._format_url(arxiv.Search(query="testing"), 0, 1)
# Rate limiting is URL-independent; expect same behavior as in
# `test_sleep_standard`.
client._parse_feed(url1)
patched_time_sleep.assert_not_called()
mock_sleep.assert_not_called()
client._last_request_dt = datetime.now()
client._parse_feed(url2)
patched_time_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3))
mock_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3))

@patch('requests.Session.get', return_value=empty_response(200))
@patch("time.sleep", return_value=None)
def test_sleep_elapsed(self, patched_time_sleep):
client = TestClient.get_code_client(200)
def test_sleep_elapsed(self, mock_sleep, mock_get):
client = arxiv.Client()
url = client._format_url(arxiv.Search(query="quantum"), 0, 1)
# If _last_request_dt is less than delay_seconds ago, sleep.
client._last_request_dt = datetime.now() - timedelta(seconds=client.delay_seconds - 1)
client._parse_feed(url)
patched_time_sleep.assert_called_once()
patched_time_sleep.reset_mock()
mock_sleep.assert_called_once()
mock_sleep.reset_mock()
# If _last_request_dt is at least delay_seconds ago, don't sleep.
client._last_request_dt = datetime.now() - timedelta(seconds=client.delay_seconds)
client._parse_feed(url)
patched_time_sleep.assert_not_called()
mock_sleep.assert_not_called()

@patch('requests.Session.get', return_value=empty_response(200))
@patch("time.sleep", return_value=None)
def test_sleep_zero_delay(self, patched_time_sleep):
client = TestClient.get_code_client(code=200, delay_seconds=0)
def test_sleep_zero_delay(self, mock_sleep, mock_get):
client = arxiv.Client(delay_seconds=0)
url = client._format_url(arxiv.Search(query="quantum"), 0, 1)
client._parse_feed(url)
client._parse_feed(url)
patched_time_sleep.assert_not_called()
mock_sleep.assert_not_called()

@patch('requests.Session.get', return_value=empty_response(500))
@patch("time.sleep", return_value=None)
def test_sleep_between_errors(self, patched_time_sleep):
client = TestClient.get_code_client(500)
def test_sleep_between_errors(self, mock_sleep, mock_get):
client = arxiv.Client()
url = client._format_url(arxiv.Search(query="quantum"), 0, 1)
try:
client._parse_feed(url)
except arxiv.HTTPError:
pass
# Should sleep between retries.
patched_time_sleep.assert_called()
self.assertEqual(patched_time_sleep.call_count, client.num_retries)
patched_time_sleep.assert_has_calls(
mock_sleep.assert_called()
self.assertEqual(mock_sleep.call_count, client.num_retries)
mock_sleep.assert_has_calls(
[
call(approx(client.delay_seconds, abs=1e-2)),
]
* client.num_retries
)

def get_code_client(code: int, delay_seconds=0.1, num_retries=3) -> arxiv.Client:
"""
get_code_client returns an arxiv.Cient with HTTP requests routed to
httpstat.us.
"""
client = arxiv.Client(delay_seconds=delay_seconds, num_retries=num_retries)
client.query_url_format = "https://teapot.fly.dev/{}?".format(code) + "{}"
return client