Skip to content

Commit

Permalink
deduplicate querystring using a pre-made url
Browse files Browse the repository at this point in the history
  • Loading branch information
xavdid-stripe committed Dec 14, 2024
1 parent 82928b0 commit e83c6e4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 6 deletions.
34 changes: 28 additions & 6 deletions stripe/_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,28 @@ def _args_for_request_with_retries(
url,
)

params = params or {}
if params and (method == "get" or method == "delete"):
# if we're sending params in the querystring, then we have to make sure we're not
# duplicating anything we got back from the server already (like in a list iterator)
# so, we parse the querystring the server sends back so we can merge with what we (or the user) are trying to send
# note: server sends back "expand[]" but users supply "expand", so we have to match them up
existing_params = {
"expand" if k == "expand[]" else k: v
for k, v in parse_qs(urlsplit(url).query).items()
}
# if a user is expanding something that wasn't expanded before, add (and deduplicate) it
if "expand" in existing_params and "expand" in params:
params["expand"] = list( # type:ignore - this is a dict
set([*existing_params["expand"], *params["expand"]])
)

params = {
**existing_params,
# user_supplied params take precedence over server params
**params,
}

encoded_params = urlencode(list(_api_encode(params or {}, api_mode)))

# Don't use strict form encoding by changing the square bracket control
Expand Down Expand Up @@ -586,13 +608,13 @@ def _args_for_request_with_retries(

if method == "get" or method == "delete":
if params:
query = encoded_params
scheme, netloc, path, base_query, fragment = urlsplit(abs_url)
# if we're sending query params, we've already merged the incoming ones with the server's "url"
# so we can overwrite the whole thing
scheme, netloc, path, _, fragment = urlsplit(abs_url)

if base_query:
query = "%s&%s" % (base_query, query)

abs_url = urlunsplit((scheme, netloc, path, query, fragment))
abs_url = urlunsplit(
(scheme, netloc, path, encoded_params, fragment)
)
post_data = None
elif method == "post":
if (
Expand Down
53 changes: 53 additions & 0 deletions tests/api_resources/test_list_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import stripe
from tests.http_client_mock import HTTPClientMock


class TestListObject(object):
Expand Down Expand Up @@ -439,6 +440,58 @@ def test_forwards_api_key_to_nested_resources(self, http_client_mock):
)
assert lo.data[0].api_key == "sk_test_iter_forwards_options"

def test_iter_with_params(self, http_client_mock: HTTPClientMock):
http_client_mock.stub_request(
"get",
path="/v1/invoices/upcoming/lines",
query_string="customer=cus_123&expand[0]=data.price&limit=1",
rbody=json.dumps(
{
"object": "list",
"data": [
{
"id": "prod_001",
"object": "product",
"price": {"object": "price", "id": "price_123"},
}
],
"url": "/v1/invoices/upcoming/lines?customer=cus_123&expand%5B%5D=data.price",
"has_more": True,
}
),
)
# second page
http_client_mock.stub_request(
"get",
path="/v1/invoices/upcoming/lines",
query_string="customer=cus_123&expand[0]=data.price&limit=1&starting_after=prod_001",
rbody=json.dumps(
{
"object": "list",
"data": [
{
"id": "prod_002",
"object": "product",
"price": {"object": "price", "id": "price_123"},
}
],
"url": "/v1/invoices/upcoming/lines?customer=cus_123&expand%5B%5D=data.price",
"has_more": False,
}
),
)

lo = stripe.Invoice.upcoming_lines(
api_key="sk_test_invoice_lines",
customer="cus_123",
expand=["data.price"],
limit=1,
)

seen = [item["id"] for item in lo.auto_paging_iter()]

assert seen == ["prod_001", "prod_002"]


class TestAutoPagingAsync:
@staticmethod
Expand Down

0 comments on commit e83c6e4

Please sign in to comment.