|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import TYPE_CHECKING, Any |
| 3 | +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Generator, Iterable, Iterator |
| 4 | +from typing import ( |
| 5 | + TYPE_CHECKING, |
| 6 | + Any, |
| 7 | + Generic, |
| 8 | + Protocol, |
| 9 | + TypeVar, |
| 10 | +) |
4 | 11 |
|
5 | 12 | from apify_client._logging import WithLogDetailsClient |
| 13 | +from apify_client._types import ListPage |
6 | 14 | from apify_client._utils import to_safe_id |
7 | 15 |
|
8 | 16 | # Conditional import only executed when type checking, otherwise we'd get circular dependency issues |
9 | 17 | if TYPE_CHECKING: |
10 | 18 | from apify_client import ApifyClient, ApifyClientAsync |
11 | 19 | from apify_client._http_client import HTTPClient, HTTPClientAsync |
| 20 | +T = TypeVar('T') |
12 | 21 |
|
13 | 22 |
|
14 | 23 | class _BaseBaseClient(metaclass=WithLogDetailsClient): |
@@ -87,6 +96,42 @@ def __init__( |
87 | 96 | self.safe_id = to_safe_id(self.resource_id) |
88 | 97 | self.url = f'{self.url}/{self.safe_id}' |
89 | 98 |
|
| 99 | + @staticmethod |
| 100 | + def _list_iterable_from_callback(callback: Callable[..., ListPage[T]], **kwargs: Any) -> ListPageProtocol[T]: |
| 101 | + """Return object can be awaited or iterated over. |
| 102 | +
|
| 103 | + Not using total from the API response as it can change during iteration. |
| 104 | + """ |
| 105 | + chunk_size = kwargs.pop('chunk_size', 0) or 0 |
| 106 | + offset = kwargs.get('offset') or 0 |
| 107 | + limit = kwargs.get('limit') or 0 |
| 108 | + |
| 109 | + list_page = callback(**{**kwargs, 'limit': _min_for_limit_param(kwargs.get('limit'), chunk_size)}) |
| 110 | + |
| 111 | + def iterator() -> Iterator[T]: |
| 112 | + current_page = list_page |
| 113 | + for item in current_page.items: |
| 114 | + yield item |
| 115 | + |
| 116 | + fetched_items = len(current_page.items) |
| 117 | + while ( |
| 118 | + current_page.items # If there are any items left to fetch |
| 119 | + and ( |
| 120 | + not limit or (limit > fetched_items) |
| 121 | + ) # If there are is limit to fetch and have not reached it yet. |
| 122 | + ): |
| 123 | + new_kwargs = { |
| 124 | + **kwargs, |
| 125 | + 'offset': offset + fetched_items, |
| 126 | + 'limit': chunk_size if not limit else _min_for_limit_param(limit - fetched_items, chunk_size), |
| 127 | + } |
| 128 | + current_page = callback(**new_kwargs) |
| 129 | + for item in current_page.items: |
| 130 | + yield item |
| 131 | + fetched_items += len(current_page.items) |
| 132 | + |
| 133 | + return IterableListPage[T](list_page, iterator()) |
| 134 | + |
90 | 135 |
|
91 | 136 | class BaseClientAsync(_BaseBaseClient): |
92 | 137 | """Base class for async sub-clients.""" |
@@ -127,3 +172,114 @@ def __init__( |
127 | 172 | if self.resource_id is not None: |
128 | 173 | self.safe_id = to_safe_id(self.resource_id) |
129 | 174 | self.url = f'{self.url}/{self.safe_id}' |
| 175 | + |
| 176 | + @staticmethod |
| 177 | + def _list_iterable_from_callback( |
| 178 | + callback: Callable[..., Awaitable[ListPage[T]]], **kwargs: Any |
| 179 | + ) -> ListPageProtocolAsync[T]: |
| 180 | + """Return object can be awaited or iterated over. |
| 181 | +
|
| 182 | + Not using total from the API response as it can change during iteration. |
| 183 | + """ |
| 184 | + chunk_size = kwargs.pop('chunk_size', 0) or 0 |
| 185 | + offset = kwargs.get('offset') or 0 |
| 186 | + limit = kwargs.get('limit') or 0 |
| 187 | + |
| 188 | + list_page_awaitable = callback(**{**kwargs, 'limit': _min_for_limit_param(kwargs.get('limit'), chunk_size)}) |
| 189 | + |
| 190 | + async def async_iterator() -> AsyncIterator[T]: |
| 191 | + current_page = await list_page_awaitable |
| 192 | + for item in current_page.items: |
| 193 | + yield item |
| 194 | + |
| 195 | + fetched_items = len(current_page.items) |
| 196 | + while ( |
| 197 | + current_page.items # If there are any items left to fetch |
| 198 | + and ( |
| 199 | + not limit or (limit > fetched_items) |
| 200 | + ) # If there are is limit to fetch and have not reached it yet. |
| 201 | + ): |
| 202 | + new_kwargs = { |
| 203 | + **kwargs, |
| 204 | + 'offset': offset + fetched_items, |
| 205 | + 'limit': chunk_size if not limit else _min_for_limit_param(limit - fetched_items, chunk_size), |
| 206 | + } |
| 207 | + current_page = await callback(**new_kwargs) |
| 208 | + for item in current_page.items: |
| 209 | + yield item |
| 210 | + fetched_items += len(current_page.items) |
| 211 | + |
| 212 | + return IterableListPageAsync[T](list_page_awaitable, async_iterator()) |
| 213 | + |
| 214 | + |
| 215 | +def _min_for_limit_param(a: int | None, b: int | None) -> int | None: |
| 216 | + """Return minimum of two limit parameters, treating None or 0 as infinity. Return None for infinity.""" |
| 217 | + # API treats 0 as None for limit parameter, in this context API understands 0 as infinity. |
| 218 | + if a == 0: |
| 219 | + a = None |
| 220 | + if b == 0: |
| 221 | + b = None |
| 222 | + if a is None: |
| 223 | + return b |
| 224 | + if b is None: |
| 225 | + return a |
| 226 | + return min(a, b) |
| 227 | + |
| 228 | + |
| 229 | +class ListPageProtocol(Iterable[T], Protocol[T]): |
| 230 | + """Protocol for an object that can be both awaited and asynchronously iterated over.""" |
| 231 | + |
| 232 | + items: list[T] |
| 233 | + """List of returned objects on this page.""" |
| 234 | + |
| 235 | + count: int |
| 236 | + """Count of the returned objects on this page.""" |
| 237 | + |
| 238 | + offset: int |
| 239 | + """The limit on the number of returned objects offset specified in the API call.""" |
| 240 | + |
| 241 | + limit: int |
| 242 | + """The offset of the first object specified in the API call""" |
| 243 | + |
| 244 | + total: int |
| 245 | + """Total number of objects matching the API call criteria.""" |
| 246 | + |
| 247 | + desc: bool |
| 248 | + """Whether the listing is descending or not.""" |
| 249 | + |
| 250 | + |
| 251 | +class ListPageProtocolAsync(AsyncIterable[T], Awaitable[ListPage[T]], Protocol[T]): |
| 252 | + """Protocol for an object that can be both awaited and asynchronously iterated over.""" |
| 253 | + |
| 254 | + |
| 255 | +class IterableListPage(ListPage[T], Generic[T]): |
| 256 | + """Can be called to get ListPage with items or iterated over to get individual items.""" |
| 257 | + |
| 258 | + def __init__(self, list_page: ListPage[T], iterator: Iterator[T]) -> None: |
| 259 | + self.items = list_page.items |
| 260 | + self.offset = list_page.offset |
| 261 | + self.limit = list_page.limit |
| 262 | + self.count = list_page.count |
| 263 | + self.total = list_page.total |
| 264 | + self.desc = list_page.desc |
| 265 | + self._iterator = iterator |
| 266 | + |
| 267 | + def __iter__(self) -> Iterator[T]: |
| 268 | + """Return an iterator over the items from API, possibly doing multiple API calls.""" |
| 269 | + return self._iterator |
| 270 | + |
| 271 | + |
| 272 | +class IterableListPageAsync(Generic[T]): |
| 273 | + """Can be awaited to get ListPage with items or asynchronously iterated over to get individual items.""" |
| 274 | + |
| 275 | + def __init__(self, awaitable: Awaitable[ListPage[T]], async_iterator: AsyncIterator[T]) -> None: |
| 276 | + self._awaitable = awaitable |
| 277 | + self._async_iterator = async_iterator |
| 278 | + |
| 279 | + def __aiter__(self) -> AsyncIterator[T]: |
| 280 | + """Return an asynchronous iterator over the items from API, possibly doing multiple API calls.""" |
| 281 | + return self._async_iterator |
| 282 | + |
| 283 | + def __await__(self) -> Generator[Any, Any, ListPage[T]]: |
| 284 | + """Return an awaitable that resolves to the ListPage doing exactly one API call.""" |
| 285 | + return self._awaitable.__await__() |
0 commit comments