Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ListObject.auto_paging_iter() implement AsyncIterator #1247

Merged
merged 8 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions stripe/_any_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import TypeVar, Iterator, AsyncIterator

T = TypeVar("T")


class AnyIterator(Iterator[T], AsyncIterator[T]):
"""
AnyIterator supports iteration through both `for ... in <AnyIterator>` and `async for ... in <AnyIterator> syntaxes.
"""

def __init__(
self, iterator: Iterator[T], async_iterator: AsyncIterator[T]
) -> None:
self._iterator = iterator
self._async_iterator = async_iterator

self._sync_iterated = False
self._async_iterated = False

def __next__(self) -> T:
if self._async_iterated:
raise RuntimeError(
"AnyIterator error: cannot mix sync and async iteration"
)
self._sync_iterated = True
return self._iterator.__next__()

async def __anext__(self) -> T:
if self._sync_iterated:
raise RuntimeError(
"AnyIterator error: cannot mix sync and async iteration"
)
self._async_iterated = True
return await self._async_iterator.__anext__()
4 changes: 1 addition & 3 deletions stripe/_charge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3918,9 +3918,7 @@ def search_auto_paging_iter(
async def search_auto_paging_iter_async(
cls, *args, **kwargs: Unpack["Charge.SearchParams"]
) -> AsyncIterator["Charge"]:
return (
await cls.search_async(*args, **kwargs)
).auto_paging_iter_async()
return (await cls.search_async(*args, **kwargs)).auto_paging_iter()

def mark_as_fraudulent(self, idempotency_key=None) -> "Charge":
params = {
Expand Down
4 changes: 1 addition & 3 deletions stripe/_customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2164,9 +2164,7 @@ def search_auto_paging_iter(
async def search_auto_paging_iter_async(
cls, *args, **kwargs: Unpack["Customer.SearchParams"]
) -> AsyncIterator["Customer"]:
return (
await cls.search_async(*args, **kwargs)
).auto_paging_iter_async()
return (await cls.search_async(*args, **kwargs)).auto_paging_iter()

@classmethod
def create_balance_transaction(
Expand Down
4 changes: 1 addition & 3 deletions stripe/_invoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -10098,9 +10098,7 @@ def search_auto_paging_iter(
async def search_auto_paging_iter_async(
cls, *args, **kwargs: Unpack["Invoice.SearchParams"]
) -> AsyncIterator["Invoice"]:
return (
await cls.search_async(*args, **kwargs)
).auto_paging_iter_async()
return (await cls.search_async(*args, **kwargs)).auto_paging_iter()

@classmethod
def list_payments(
Expand Down
12 changes: 10 additions & 2 deletions stripe/_list_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
from stripe._api_requestor import (
_APIRequestor, # pyright: ignore[reportPrivateUsage]
)
from stripe._any_iterator import AnyIterator
from stripe._stripe_object import StripeObject
from stripe._request_options import RequestOptions, extract_options_from_dict

from urllib.parse import quote_plus


T = TypeVar("T", bound=StripeObject)


Expand Down Expand Up @@ -123,7 +125,13 @@ def __len__(self) -> int:
def __reversed__(self) -> Iterator[T]: # pyright: ignore (see above)
return getattr(self, "data", []).__reversed__()

def auto_paging_iter(self) -> Iterator[T]:
def auto_paging_iter(self) -> AnyIterator[T]:
return AnyIterator(
self._auto_paging_iter(),
self._auto_paging_iter_async(),
)

def _auto_paging_iter(self) -> Iterator[T]:
page = self

while True:
Expand All @@ -142,7 +150,7 @@ def auto_paging_iter(self) -> Iterator[T]:
if page.is_empty:
break

async def auto_paging_iter_async(self) -> AsyncIterator[T]:
async def _auto_paging_iter_async(self) -> AsyncIterator[T]:
page = self

while True:
Expand Down
4 changes: 1 addition & 3 deletions stripe/_payment_intent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12953,9 +12953,7 @@ def search_auto_paging_iter(
async def search_auto_paging_iter_async(
cls, *args, **kwargs: Unpack["PaymentIntent.SearchParams"]
) -> AsyncIterator["PaymentIntent"]:
return (
await cls.search_async(*args, **kwargs)
).auto_paging_iter_async()
return (await cls.search_async(*args, **kwargs)).auto_paging_iter()

_inner_class_types = {
"amount_details": AmountDetails,
Expand Down
4 changes: 1 addition & 3 deletions stripe/_price.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,9 +926,7 @@ def search_auto_paging_iter(
async def search_auto_paging_iter_async(
cls, *args, **kwargs: Unpack["Price.SearchParams"]
) -> AsyncIterator["Price"]:
return (
await cls.search_async(*args, **kwargs)
).auto_paging_iter_async()
return (await cls.search_async(*args, **kwargs)).auto_paging_iter()

_inner_class_types = {
"currency_options": CurrencyOptions,
Expand Down
4 changes: 1 addition & 3 deletions stripe/_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,9 +871,7 @@ def search_auto_paging_iter(
async def search_auto_paging_iter_async(
cls, *args, **kwargs: Unpack["Product.SearchParams"]
) -> AsyncIterator["Product"]:
return (
await cls.search_async(*args, **kwargs)
).auto_paging_iter_async()
return (await cls.search_async(*args, **kwargs)).auto_paging_iter()

_inner_class_types = {
"features": Feature,
Expand Down
10 changes: 8 additions & 2 deletions stripe/_search_result_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from stripe import _util
import warnings
from stripe._request_options import RequestOptions, extract_options_from_dict
from stripe._any_iterator import AnyIterator

T = TypeVar("T", bound=StripeObject)

Expand Down Expand Up @@ -91,7 +92,7 @@ def __iter__(self) -> Iterator[T]: # pyright: ignore
def __len__(self) -> int:
return getattr(self, "data", []).__len__()

def auto_paging_iter(self) -> Iterator[T]:
def _auto_paging_iter(self) -> Iterator[T]:
page = self

while True:
Expand All @@ -102,7 +103,12 @@ def auto_paging_iter(self) -> Iterator[T]:
if page.is_empty:
break

async def auto_paging_iter_async(self) -> AsyncIterator[T]:
def auto_paging_iter(self) -> AnyIterator[T]:
return AnyIterator(
self._auto_paging_iter(), self._auto_paging_iter_async()
)

async def _auto_paging_iter_async(self) -> AsyncIterator[T]:
page = self

while True:
Expand Down
4 changes: 1 addition & 3 deletions stripe/_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -2976,9 +2976,7 @@ def search_auto_paging_iter(
async def search_auto_paging_iter_async(
cls, *args, **kwargs: Unpack["Subscription.SearchParams"]
) -> AsyncIterator["Subscription"]:
return (
await cls.search_async(*args, **kwargs)
).auto_paging_iter_async()
return (await cls.search_async(*args, **kwargs)).auto_paging_iter()

_inner_class_types = {
"automatic_tax": AutomaticTax,
Expand Down
6 changes: 3 additions & 3 deletions tests/api_resources/test_list_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ async def test_iter_one_page(self, http_client_mock):

http_client_mock.assert_no_request()

seen = [item["id"] async for item in lo.auto_paging_iter_async()]
seen = [item["id"] async for item in lo.auto_paging_iter()]

assert seen == ["pm_123", "pm_124"]

Expand All @@ -464,7 +464,7 @@ async def test_iter_two_pages(self, http_client_mock):
),
)

seen = [item["id"] async for item in lo.auto_paging_iter_async()]
seen = [item["id"] async for item in lo.auto_paging_iter()]

http_client_mock.assert_requested(
"get",
Expand All @@ -490,7 +490,7 @@ async def test_iter_reverse(self, http_client_mock):
),
)

seen = [item["id"] async for item in lo.auto_paging_iter_async()]
seen = [item["id"] async for item in lo.auto_paging_iter()]

http_client_mock.assert_requested(
"get",
Expand Down
4 changes: 2 additions & 2 deletions tests/api_resources/test_search_result_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ async def test_iter_one_page(self, http_client_mock):

http_client_mock.assert_no_request()

seen = [item["id"] async for item in sro.auto_paging_iter_async()]
seen = [item["id"] async for item in sro.auto_paging_iter()]

assert seen == ["pm_123", "pm_124"]

Expand All @@ -300,7 +300,7 @@ async def test_iter_two_pages(self, http_client_mock):
),
)

seen = [item["id"] async for item in sro.auto_paging_iter_async()]
seen = [item["id"] async for item in sro.auto_paging_iter()]

http_client_mock.assert_requested(
"get",
Expand Down
Loading