Skip to content

Commit

Permalink
add more infra
Browse files Browse the repository at this point in the history
  • Loading branch information
richardm-stripe committed Feb 2, 2024
1 parent 83bca93 commit cbb0170
Show file tree
Hide file tree
Showing 8 changed files with 365 additions and 25 deletions.
61 changes: 61 additions & 0 deletions stripe/_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,37 @@ def request(
api_mode=api_mode,
)

async def request_async(
self,
method: str,
url: str,
params: Optional[Mapping[str, Any]] = None,
options: Optional[RequestOptions] = None,
*,
base_address: BaseAddress,
api_mode: ApiMode,
_usage: Optional[List[str]] = None,
) -> "StripeObject":
requestor = self._replace_options(options)
rbody, rcode, rheaders = await requestor.request_raw_async(
method.lower(),
url,
params,
is_streaming=False,
api_mode=api_mode,
base_address=base_address,
options=options,
_usage=_usage,
)
resp = requestor._interpret_response(rbody, rcode, rheaders)

return _convert_to_stripe_object(
resp=resp,
params=params,
requestor=requestor,
api_mode=api_mode,
)

def request_stream(
self,
method: str,
Expand Down Expand Up @@ -235,6 +266,36 @@ def request_stream(
)
return resp

async def request_stream_async(
self,
method: str,
url: str,
params: Optional[Mapping[str, Any]] = None,
options: Optional[RequestOptions] = None,
*,
base_address: BaseAddress,
api_mode: ApiMode,
_usage: Optional[List[str]] = None,
) -> StripeStreamResponse:
stream, rcode, rheaders = await self.request_raw_async(
method.lower(),
url,
params,
is_streaming=True,
api_mode=api_mode,
base_address=base_address,
options=options,
_usage=_usage,
)
resp = self._interpret_streaming_response(
# TODO: should be able to remove this cast once self._client.request_stream_with_retries
# returns a more specific type.
cast(IOBase, stream),
rcode,
rheaders,
)
return resp

def handle_error_response(self, rbody, rcode, resp, rheaders) -> NoReturn:
try:
error_data = resp["error"]
Expand Down
122 changes: 103 additions & 19 deletions stripe/_api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ def retrieve(cls, id, **params) -> T:
def refresh(self) -> Self:
return self._request_and_refresh("get", self.instance_url())

async def refresh_async(self) -> Self:
return await self._request_and_refresh_async(
"get", self.instance_url()
)

@classmethod
def class_url(cls) -> str:
if cls == APIResource:
Expand Down Expand Up @@ -64,21 +69,43 @@ def instance_url(self) -> str:
extn = quote_plus(id)
return "%s/%s" % (base, extn)

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
def _request(
self,
method_,
url_,
method,
url,
params=None,
*,
base_address: BaseAddress = "api",
api_mode: ApiMode = "V1",
) -> StripeObject:
obj = StripeObject._request(
self,
method_,
url_,
method,
url,
params=params,
base_address=base_address,
api_mode=api_mode,
)

if type(self) is type(obj):
self._refresh_from(values=obj, api_mode=api_mode)
return self
else:
return obj

async def _request_async(
self,
method,
url,
params=None,
*,
base_address: BaseAddress = "api",
api_mode: ApiMode = "V1",
) -> StripeObject:
obj = await StripeObject._request_async(
self,
method,
url,
params=params,
base_address=base_address,
api_mode=api_mode,
Expand All @@ -90,12 +117,10 @@ def _request(
else:
return obj

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
def _request_and_refresh(
self,
method_: Literal["get", "post", "delete"],
url_: str,
method: Literal["get", "post", "delete"],
url: str,
params: Optional[Mapping[str, Any]] = None,
_usage: Optional[List[str]] = None,
*,
Expand All @@ -104,8 +129,31 @@ def _request_and_refresh(
) -> Self:
obj = StripeObject._request(
self,
method_,
url_,
method,
url,
params=params,
base_address=base_address,
api_mode=api_mode,
_usage=_usage,
)

self._refresh_from(values=obj, api_mode=api_mode)
return self

async def _request_and_refresh_async(
self,
method: Literal["get", "post", "delete"],
url: str,
params: Optional[Mapping[str, Any]] = None,
_usage: Optional[List[str]] = None,
*,
base_address: BaseAddress = "api",
api_mode: ApiMode = "V1",
) -> Self:
obj = await StripeObject._request_async(
self,
method,
url,
params=params,
base_address=base_address,
api_mode=api_mode,
Expand All @@ -115,8 +163,6 @@ def _request_and_refresh(
self._refresh_from(values=obj, api_mode=api_mode)
return self

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
@classmethod
def _static_request(
cls,
Expand All @@ -137,24 +183,62 @@ def _static_request(
api_mode=api_mode,
)

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
@classmethod
def _static_request_stream(
async def _static_request_async(
cls,
method_,
url_,
params=None,
params: Optional[Mapping[str, Any]] = None,
*,
base_address: BaseAddress = "api",
api_mode: ApiMode = "V1",
):
request_options, request_params = extract_options_from_dict(params)
return _APIRequestor._global_instance().request_stream(
return await _APIRequestor._global_instance().request_async(
method_,
url_,
params=request_params,
options=request_options,
base_address=base_address,
api_mode=api_mode,
)

@classmethod
def _static_request_stream(
cls,
method,
url,
params=None,
*,
base_address: BaseAddress = "api",
api_mode: ApiMode = "V1",
):
request_options, request_params = extract_options_from_dict(params)
return _APIRequestor._global_instance().request_stream(
method,
url,
params=request_params,
options=request_options,
base_address=base_address,
api_mode=api_mode,
)

@classmethod
async def _static_request_stream_async(
cls,
method,
url,
params=None,
*,
base_address: BaseAddress = "api",
api_mode: ApiMode = "V1",
):
request_options, request_params = extract_options_from_dict(params)
return await _APIRequestor._global_instance().request_stream_async(
method,
url,
params=request_params,
options=request_options,
base_address=base_address,
api_mode=api_mode,
)
82 changes: 82 additions & 0 deletions stripe/_list_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import (
Any,
AsyncIterator,
Iterator,
List,
Generic,
Expand Down Expand Up @@ -46,6 +47,23 @@ def list(self, **params: Mapping[str, Any]) -> Self:
),
)

async def list_async(self, **params: Mapping[str, Any]) -> Self:
url = self.get("url")
if not isinstance(url, str):
raise ValueError(
'Cannot call .list on a list object without a string "url" property'
)
return cast(
Self,
await self._request_async(
"get",
url,
params=params,
base_address="api",
api_mode="V1",
),
)

def create(self, **params: Mapping[str, Any]) -> T:
url = self.get("url")
if not isinstance(url, str):
Expand Down Expand Up @@ -126,6 +144,25 @@ def auto_paging_iter(self) -> Iterator[T]:
if page.is_empty:
break

async def auto_paging_iter_async(self) -> AsyncIterator[T]:
page = self

while True:
if (
"ending_before" in self._retrieve_params
and "starting_after" not in self._retrieve_params
):
for item in reversed(page):
yield item
page = await page.previous_page_async()
else:
for item in page:
yield item
page = await page.next_page_async()

if page.is_empty:
break

@classmethod
def _empty_list(
cls,
Expand Down Expand Up @@ -165,6 +202,27 @@ def next_page(self, **params: Unpack[RequestOptions]) -> Self:
**params_with_filters,
)

async def next_page_async(self, **params: Unpack[RequestOptions]) -> Self:
if not self.has_more:
request_options, _ = extract_options_from_dict(params)
return self._empty_list(
**request_options,
)

last_id = getattr(self.data[-1], "id")
if not last_id:
raise ValueError(
"Unexpected: element in .data of list object had no id"
)

params_with_filters = dict(self._retrieve_params)
params_with_filters.update({"starting_after": last_id})
params_with_filters.update(params)

return await self.list_async(
**params_with_filters,
)

def previous_page(self, **params: Unpack[RequestOptions]) -> Self:
if not self.has_more:
request_options, _ = extract_options_from_dict(params)
Expand All @@ -186,3 +244,27 @@ def previous_page(self, **params: Unpack[RequestOptions]) -> Self:
**params_with_filters,
)
return result

async def previous_page_async(
self, **params: Unpack[RequestOptions]
) -> Self:
if not self.has_more:
request_options, _ = extract_options_from_dict(params)
return self._empty_list(
**request_options,
)

first_id = getattr(self.data[0], "id")
if not first_id:
raise ValueError(
"Unexpected: element in .data of list object had no id"
)

params_with_filters = dict(self._retrieve_params)
params_with_filters.update({"ending_before": first_id})
params_with_filters.update(params)

result = await self.list_async(
**params_with_filters,
)
return result
Loading

0 comments on commit cbb0170

Please sign in to comment.