From ef9d5b00858e80bb021bc172679bd7ab32b7f34c Mon Sep 17 00:00:00 2001 From: David Brownman <109395161+xavdid-stripe@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:53:26 -0800 Subject: [PATCH] Fix using `auto_paging_iter()` with `expand: [...]` (#1434) * deduplicate querystring using a pre-made url * fix tests --- stripe/_api_requestor.py | 43 ++++++++++++++++---- tests/api_resources/test_list_object.py | 53 +++++++++++++++++++++++++ tests/test_api_requestor.py | 9 +++-- 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/stripe/_api_requestor.py b/stripe/_api_requestor.py index 65bb449fe..c0fa8a561 100644 --- a/stripe/_api_requestor.py +++ b/stripe/_api_requestor.py @@ -20,7 +20,7 @@ Unpack, ) import uuid -from urllib.parse import urlsplit, urlunsplit +from urllib.parse import urlsplit, urlunsplit, parse_qs # breaking circular dependency import stripe # noqa: IMP101 @@ -556,6 +556,35 @@ def _args_for_request_with_retries( url, ) + params = params or {} + if params and (method == "get" or method == "delete"): + # if we're sending params in the querystring, then we have to make sure we're not + # duplicating anything we got back from the server already (like in a list iterator) + # so, we parse the querystring the server sends back so we can merge with what we (or the user) are trying to send + existing_params = {} + for k, v in parse_qs(urlsplit(url).query).items(): + # note: server sends back "expand[]" but users supply "expand", so we strip the brackets from the key name + if k.endswith("[]"): + existing_params[k[:-2]] = v + else: + # all querystrings are pulled out as lists. + # We want to keep the querystrings that actually are lists, but flatten the ones that are single values + existing_params[k] = v[0] if len(v) == 1 else v + + # if a user is expanding something that wasn't expanded before, add (and deduplicate) it + # this could theoretically work for other lists that we want to merge too, but that doesn't seem to be a use case + # it never would have worked before, so I think we can start with `expand` and go from there + if "expand" in existing_params and "expand" in params: + params["expand"] = list( # type:ignore - this is a dict + set([*existing_params["expand"], *params["expand"]]) + ) + + params = { + **existing_params, + # user_supplied params take precedence over server params + **params, + } + encoded_params = urlencode(list(_api_encode(params or {}, api_mode))) # Don't use strict form encoding by changing the square bracket control @@ -586,13 +615,13 @@ def _args_for_request_with_retries( if method == "get" or method == "delete": if params: - query = encoded_params - scheme, netloc, path, base_query, fragment = urlsplit(abs_url) + # if we're sending query params, we've already merged the incoming ones with the server's "url" + # so we can overwrite the whole thing + scheme, netloc, path, _, fragment = urlsplit(abs_url) - if base_query: - query = "%s&%s" % (base_query, query) - - abs_url = urlunsplit((scheme, netloc, path, query, fragment)) + abs_url = urlunsplit( + (scheme, netloc, path, encoded_params, fragment) + ) post_data = None elif method == "post": if ( diff --git a/tests/api_resources/test_list_object.py b/tests/api_resources/test_list_object.py index fe6340a14..6de3b4beb 100644 --- a/tests/api_resources/test_list_object.py +++ b/tests/api_resources/test_list_object.py @@ -3,6 +3,7 @@ import pytest import stripe +from tests.http_client_mock import HTTPClientMock class TestListObject(object): @@ -439,6 +440,58 @@ def test_forwards_api_key_to_nested_resources(self, http_client_mock): ) assert lo.data[0].api_key == "sk_test_iter_forwards_options" + def test_iter_with_params(self, http_client_mock: HTTPClientMock): + http_client_mock.stub_request( + "get", + path="/v1/invoices/upcoming/lines", + query_string="customer=cus_123&expand[0]=data.price&limit=1", + rbody=json.dumps( + { + "object": "list", + "data": [ + { + "id": "prod_001", + "object": "product", + "price": {"object": "price", "id": "price_123"}, + } + ], + "url": "/v1/invoices/upcoming/lines?customer=cus_123&expand%5B%5D=data.price", + "has_more": True, + } + ), + ) + # second page + http_client_mock.stub_request( + "get", + path="/v1/invoices/upcoming/lines", + query_string="customer=cus_123&expand[0]=data.price&limit=1&starting_after=prod_001", + rbody=json.dumps( + { + "object": "list", + "data": [ + { + "id": "prod_002", + "object": "product", + "price": {"object": "price", "id": "price_123"}, + } + ], + "url": "/v1/invoices/upcoming/lines?customer=cus_123&expand%5B%5D=data.price", + "has_more": False, + } + ), + ) + + lo = stripe.Invoice.upcoming_lines( + api_key="sk_test_invoice_lines", + customer="cus_123", + expand=["data.price"], + limit=1, + ) + + seen = [item["id"] for item in lo.auto_paging_iter()] + + assert seen == ["prod_001", "prod_002"] + class TestAutoPagingAsync: @staticmethod diff --git a/tests/test_api_requestor.py b/tests/test_api_requestor.py index 82f10e5a3..0c8b028d4 100644 --- a/tests/test_api_requestor.py +++ b/tests/test_api_requestor.py @@ -245,16 +245,17 @@ def test_ordereddict_encoding(self): def test_url_construction(self, requestor, http_client_mock): CASES = ( - ("%s?foo=bar" % stripe.api_base, "", {"foo": "bar"}), - ("%s?foo=bar" % stripe.api_base, "?", {"foo": "bar"}), + (f"{stripe.api_base}?foo=bar", "", {"foo": "bar"}), + (f"{stripe.api_base}?foo=bar", "?", {"foo": "bar"}), (stripe.api_base, "", {}), ( - "%s/%%20spaced?foo=bar%%24&baz=5" % stripe.api_base, + f"{stripe.api_base}/%20spaced?baz=5&foo=bar%24", "/%20spaced?foo=bar%24", {"baz": "5"}, ), + # duplicate query params keys should be deduped ( - "%s?foo=bar&foo=bar" % stripe.api_base, + f"{stripe.api_base}?foo=bar", "?foo=bar", {"foo": "bar"}, ),