Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing: unify http client mock #1242

Merged
merged 3 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions tests/api_resources/abstract/test_custom_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def test_call_custom_list_method_class_paginates(self, http_client_mock):

assert ids == ["cus_1", "cus_2", "cus_3"]

def test_call_custom_stream_method_class(self, http_client_mock_streaming):
http_client_mock_streaming.stub_request(
def test_call_custom_stream_method_class(self, http_client_mock):
http_client_mock.stub_request(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
rbody=util.io.BytesIO(str.encode("response body")),
Expand All @@ -119,7 +119,7 @@ def test_call_custom_stream_method_class(self, http_client_mock_streaming):

resp = self.MyResource.do_stream_stuff("mid", foo="bar")

http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
post_data="foo=bar",
Expand Down Expand Up @@ -150,9 +150,9 @@ def test_call_custom_method_class_with_object(self, http_client_mock):
assert obj.thing_done is True

def test_call_custom_stream_method_class_with_object(
self, http_client_mock_streaming
self, http_client_mock
):
http_client_mock_streaming.stub_request(
http_client_mock.stub_request(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
rbody=util.io.BytesIO(str.encode("response body")),
Expand All @@ -162,7 +162,7 @@ def test_call_custom_stream_method_class_with_object(
obj = self.MyResource.construct_from({"id": "mid"}, "mykey")
resp = self.MyResource.do_stream_stuff(obj, foo="bar")

http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
post_data="foo=bar",
Expand Down Expand Up @@ -192,10 +192,8 @@ def test_call_custom_method_instance(self, http_client_mock):
)
assert obj.thing_done is True

def test_call_custom_stream_method_instance(
self, http_client_mock_streaming
):
http_client_mock_streaming.stub_request(
def test_call_custom_stream_method_instance(self, http_client_mock):
http_client_mock.stub_request(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
rbody=util.io.BytesIO(str.encode("response body")),
Expand All @@ -205,7 +203,7 @@ def test_call_custom_stream_method_instance(
obj = self.MyResource.construct_from({"id": "mid"}, "mykey")
resp = obj.do_stream_stuff(foo="bar")

http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"post",
path="/v1/myresources/mid/do_the_stream_thing",
post_data="foo=bar",
Expand Down
8 changes: 4 additions & 4 deletions tests/api_resources/test_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def test_can_list_computed_upfront_line_items_classmethod(
)
assert isinstance(resources.data[0], stripe.LineItem)

def test_can_pdf(self, setup_upload_api_base, http_client_mock_streaming):
def test_can_pdf(self, setup_upload_api_base, http_client_mock):
resource = stripe.Quote.retrieve(TEST_RESOURCE_ID)
stream = resource.pdf()
http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"get",
api_base=stripe.upload_api_base,
path="/v1/quotes/%s/pdf" % TEST_RESOURCE_ID,
Expand All @@ -152,10 +152,10 @@ def test_can_pdf(self, setup_upload_api_base, http_client_mock_streaming):
assert content == b"Stripe binary response"

def test_can_pdf_classmethod(
self, setup_upload_api_base, http_client_mock_streaming
self, setup_upload_api_base, http_client_mock
):
stream = stripe.Quote.pdf(TEST_RESOURCE_ID)
http_client_mock_streaming.assert_requested(
http_client_mock.assert_requested(
"get",
api_base=stripe.upload_api_base,
path="/v1/quotes/%s/pdf" % TEST_RESOURCE_ID,
Expand Down
27 changes: 0 additions & 27 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ def http_client_mock(mocker):
stripe.default_http_client = old_client


@pytest.fixture
def http_client_mock_streaming(mocker):
mock_client = HTTPClientMock(mocker, is_streaming=True)
old_client = stripe.default_http_client
stripe.default_http_client = mock_client.get_mock_http_client()
yield mock_client
stripe.default_http_client = old_client


@pytest.fixture
def stripe_mock_stripe_client(http_client_mock):
return StripeClient(
Expand All @@ -115,21 +106,3 @@ def file_stripe_mock_stripe_client(http_client_mock):
base_addresses={"files": MOCK_API_BASE},
http_client=http_client_mock.get_mock_http_client(),
)


@pytest.fixture
def stripe_mock_stripe_client_streaming(http_client_mock_streaming):
return StripeClient(
MOCK_API_KEY,
base_addresses={"api": MOCK_API_BASE},
http_client=http_client_mock_streaming.get_mock_http_client(),
)


@pytest.fixture
def file_stripe_mock_stripe_client_streaming(http_client_mock_streaming):
return StripeClient(
MOCK_API_KEY,
base_addresses={"files": MOCK_API_BASE},
http_client=http_client_mock_streaming.get_mock_http_client(),
)
148 changes: 73 additions & 75 deletions tests/http_client_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,28 +212,19 @@ def assert_post_data(self, expected, is_json=False):


class HTTPClientMock(object):
def __init__(self, mocker, is_streaming=False, is_async=False):
if is_async:
self.mock_client = mocker.Mock(
wraps=stripe.http_client.new_default_http_client_async()
)
else:
self.mock_client = mocker.Mock(
wraps=stripe.http_client.new_default_http_client()
)
def __init__(self, mocker):
self.mock_client = mocker.Mock(
wraps=stripe.http_client.new_default_http_client()
)

self.is_async = is_async
self.mock_client._verify_ssl_certs = True
self.mock_client.name = "mockclient"
if is_async and is_streaming:
self.func = self.mock_client.request_stream_with_retries_async
elif is_async and not is_streaming:
self.func = self.mock_client.request_with_retries_async
elif is_streaming:
self.func = self.mock_client.request_stream_with_retries
else:
self.func = self.mock_client.request_with_retries
self.registered_responses = {}
self.funcs = [
self.mock_client.request_with_retries,
self.mock_client.request_stream_with_retries,
]
self.func_call_order = []

def get_mock_http_client(self) -> Mock:
return self.mock_client
Expand All @@ -247,73 +238,78 @@ def stub_request(
rcode=200,
rheaders={},
) -> None:
def custom_side_effect(called_method, called_abs_url, *args, **kwargs):
called_path = urlsplit(called_abs_url).path
called_query = ""
if urlsplit(called_abs_url).query:
called_query = urlencode(
parse_and_sort(urlsplit(called_abs_url).query)
)
if (
called_method,
called_path,
called_query,
) not in self.registered_responses:
raise AssertionError(
"Unexpected request made to %s %s %s"
% (called_method, called_path, called_query)
)
return self.registered_responses[
(called_method, called_path, called_query)
]

async def awaitable(x):
return x
def custom_side_effect_for_func(func):
def custom_side_effect(
called_method, called_abs_url, *args, **kwargs
):
self.func_call_order.append(func)
called_path = urlsplit(called_abs_url).path
called_query = ""
if urlsplit(called_abs_url).query:
called_query = urlencode(
parse_and_sort(urlsplit(called_abs_url).query)
)
if (
called_method,
called_path,
called_query,
) not in self.registered_responses:
raise AssertionError(
"Unexpected request made to %s %s %s"
% (called_method, called_path, called_query)
)
ret = self.registered_responses[
(called_method, called_path, called_query)
]
return ret

return custom_side_effect

self.registered_responses[
(method, path, urlencode(parse_and_sort(query_string)))
] = (
awaitable(
(
rbody,
rcode,
rheaders,
)
)
if self.is_async
else (rbody, rcode, rheaders)
)
] = (rbody, rcode, rheaders)

self.func.side_effect = custom_side_effect
for func in self.funcs:
func.side_effect = custom_side_effect_for_func(func)

def get_last_call(self) -> StripeRequestCall:
if not self.func.called:
if len(self.func_call_order) == 0:
raise AssertionError(
"Expected request to have been made, but no calls were found."
)
return StripeRequestCall.from_mock_call(self.func.call_args)
return StripeRequestCall.from_mock_call(
self.func_call_order[-1].call_args
)

def get_all_calls(self) -> List[StripeRequestCall]:
calls_by_func = {
func: list(func.call_args_list) for func in self.funcs
}

calls = []
for func in self.func_call_order:
calls.append(calls_by_func[func].pop(0))

return [
StripeRequestCall.from_mock_call(call_args)
for call_args in self.func.call_args_list
StripeRequestCall.from_mock_call(call_args) for call_args in calls
]

def find_call(
self, method, api_base, path, query_string
) -> StripeRequestCall:
for call_args in self.func.call_args_list:
request_call = StripeRequestCall.from_mock_call(call_args)
try:
if request_call.check(
method=method,
api_base=api_base,
path=path,
query_string=query_string,
):
return request_call
except AssertionError:
pass
for func in self.funcs:
for call_args in func.call_args_list:
request_call = StripeRequestCall.from_mock_call(call_args)
try:
if request_call.check(
method=method,
api_base=api_base,
path=path,
query_string=query_string,
):
return request_call
except AssertionError:
pass
raise AssertionError(
"Expected request to have been made, but no calls were found."
)
Expand Down Expand Up @@ -369,13 +365,15 @@ def assert_requested(
)

def assert_no_request(self):
if self.func.called:
msg = (
"Expected no request to have been made, but %s calls were "
"found." % (self.func.call_count)
)
raise AssertionError(msg)
for func in self.funcs:
if func.called:
msg = (
"Expected no request to have been made, but %s calls were "
"found." % (sum([func.call_count for func in self.funcs]))
)
raise AssertionError(msg)

def reset_mock(self):
self.func.reset_mock()
for func in self.funcs:
func.reset_mock()
self.registered_responses = {}
10 changes: 4 additions & 6 deletions tests/services/test_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,11 @@ def test_can_list_computed_upfront_line_items(

def test_can_pdf(
self,
file_stripe_mock_stripe_client_streaming,
http_client_mock_streaming,
file_stripe_mock_stripe_client,
http_client_mock,
):
stream = file_stripe_mock_stripe_client_streaming.quotes.pdf(
TEST_RESOURCE_ID
)
http_client_mock_streaming.assert_requested(
stream = file_stripe_mock_stripe_client.quotes.pdf(TEST_RESOURCE_ID)
http_client_mock.assert_requested(
"get",
api_base=stripe.upload_api_base,
path="/v1/quotes/%s/pdf" % TEST_RESOURCE_ID,
Expand Down
Loading
Loading