Skip to content

Commit

Permalink
Beta: async streaming (#1233)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardm-stripe authored Feb 14, 2024
1 parent bb58734 commit c4bc5d2
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 65 deletions.
1 change: 1 addition & 0 deletions flake8_stripe/flake8_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class TypingImportsChecker:
allowed_typing_imports = [
"Any",
"AsyncIterator",
"AsyncIterable",
"ClassVar",
"Optional",
"TypeVar",
Expand Down
1 change: 1 addition & 0 deletions stripe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def set_app_info(
from stripe._stripe_response import StripeResponseBase as StripeResponseBase
from stripe._stripe_response import (
StripeStreamResponse as StripeStreamResponse,
StripeStreamResponseAsync as StripeStreamResponseAsync,
)

# Error types
Expand Down
41 changes: 33 additions & 8 deletions stripe/_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import platform
from typing import (
Any,
AsyncIterable,
Dict,
List,
Mapping,
Expand All @@ -12,7 +13,12 @@
cast,
ClassVar,
)
from typing_extensions import TYPE_CHECKING, Literal, NoReturn, Unpack
from typing_extensions import (
TYPE_CHECKING,
Literal,
NoReturn,
Unpack,
)
import uuid
from urllib.parse import urlsplit, urlunsplit

Expand All @@ -33,7 +39,11 @@
_api_encode,
_json_encode_date_callback,
)
from stripe._stripe_response import StripeResponse, StripeStreamResponse
from stripe._stripe_response import (
StripeResponse,
StripeStreamResponse,
StripeStreamResponseAsync,
)
from stripe._request_options import RequestOptions, merge_options
from stripe._requestor_options import (
RequestorOptions,
Expand Down Expand Up @@ -276,7 +286,7 @@ async def request_stream_async(
base_address: BaseAddress,
api_mode: ApiMode,
_usage: Optional[List[str]] = None,
) -> StripeStreamResponse:
) -> StripeStreamResponseAsync:
stream, rcode, rheaders = await self.request_raw_async(
method.lower(),
url,
Expand All @@ -287,10 +297,8 @@ async def request_stream_async(
options=options,
_usage=_usage,
)
resp = self._interpret_streaming_response(
# TODO: should be able to remove this cast once self._client.request_stream_with_retries
# returns a more specific type.
cast(IOBase, stream),
resp = await self._interpret_streaming_response_async(
stream,
rcode,
rheaders,
)
Expand Down Expand Up @@ -654,7 +662,7 @@ async def request_raw_async(
base_address: BaseAddress,
api_mode: ApiMode,
_usage: Optional[List[str]] = None,
) -> Tuple[object, int, Mapping[str, str]]:
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
"""
Mechanism for issuing an API call
"""
Expand Down Expand Up @@ -819,6 +827,22 @@ def _interpret_response(
self.handle_error_response(rbody, rcode, resp.data, rheaders)
return resp

async def _interpret_streaming_response_async(
self,
stream: AsyncIterable[bytes],
rcode: int,
rheaders: Mapping[str, str],
) -> StripeStreamResponseAsync:
if self._should_handle_code_as_error(rcode):
json_content = b"".join([chunk async for chunk in stream])
self._interpret_response(json_content, rcode, rheaders)
# _interpret_response is guaranteed to throw since we've checked self._should_handle_code_as_error
raise RuntimeError(
"_interpret_response should have raised an error"
)
else:
return StripeStreamResponseAsync(stream, rcode, rheaders)

def _interpret_streaming_response(
self,
stream: IOBase,
Expand All @@ -838,6 +862,7 @@ def _interpret_streaming_response(
raise NotImplementedError(
"HTTP client %s does not return an IOBase object which "
"can be consumed when streaming a response."
% self._get_http_client().name
)

self._interpret_response(json_content, rcode, rheaders)
Expand Down
122 changes: 93 additions & 29 deletions stripe/_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Union,
cast,
overload,
AsyncIterable,
)
from typing_extensions import (
Literal,
Expand Down Expand Up @@ -418,11 +419,11 @@ def close(self):
class HTTPClientAsync(HTTPClientBase):
async def request_with_retries_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
max_network_retries=None,
max_network_retries: Optional[int] = None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
Expand All @@ -438,14 +439,14 @@ async def request_with_retries_async(

async def request_stream_with_retries_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
max_network_retries=None,
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Any]:
) -> Tuple[AsyncIterable[bytes], int, Any]:
return await self._request_with_retries_internal_async(
method,
url,
Expand All @@ -462,17 +463,45 @@ async def sleep_async(cls: Type[Self], secs: float) -> Awaitable[None]:
"HTTPClientAsync subclasses must implement `sleep`"
)

@overload
async def _request_with_retries_internal_async(
self,
method,
url,
headers,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming,
max_network_retries,
is_streaming: Literal[False],
max_network_retries: Optional[int],
*,
_usage=None
):
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Mapping[str, str]]:
...

@overload
async def _request_with_retries_internal_async(
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming: Literal[True],
max_network_retries: Optional[int],
*,
_usage: Optional[List[str]] = None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
...

async def _request_with_retries_internal_async(
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data,
is_streaming: bool,
max_network_retries: Optional[int],
*,
_usage: Optional[List[str]] = None
) -> Tuple[Any, int, Mapping[str, str]]:
self._add_telemetry_header(headers)

num_retries = 0
Expand Down Expand Up @@ -523,14 +552,18 @@ async def _request_with_retries_internal_async(
assert connection_error is not None
raise connection_error

async def request_async(self, method, url, headers, post_data=None):
async def request_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[bytes, int, Mapping[str, str]]:
raise NotImplementedError(
"HTTPClientAsync subclasses must implement `request`"
"HTTPClientAsync subclasses must implement `request_async`"
)

async def request_stream_async(self, method, url, headers, post_data=None):
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
raise NotImplementedError(
"HTTPClientAsync subclasses must implement `request_stream`"
"HTTPClientAsync subclasses must implement `request_stream_async`"
)

async def close_async(self):
Expand Down Expand Up @@ -1189,21 +1222,34 @@ def __init__(
def sleep_async(self, secs):
return self.anyio.sleep(secs)

async def request_async(
self, method, url, headers, post_data=None, timeout=80.0
) -> Tuple[bytes, int, Mapping[str, str]]:
def _get_request_args_kwargs(
self, method: str, url: str, headers: Mapping[str, str], post_data
):
kwargs = {}

if self._proxy:
kwargs["proxies"] = self._proxy

if self._timeout:
kwargs["timeout"] = self._timeout
return [
(method, url),
{"headers": headers, "data": post_data or {}, **kwargs},
]

async def request_async(
self,
method: str,
url: str,
headers: Mapping[str, str],
post_data=None,
timeout: float = 80.0,
) -> Tuple[bytes, int, Mapping[str, str]]:
args, kwargs = self._get_request_args_kwargs(
method, url, headers, post_data
)
try:
response = await self._client.request(
method, url, headers=headers, data=post_data or {}, **kwargs
)
response = await self._client.request(*args, **kwargs)
except Exception as e:
self._handle_request_error(e)

Expand All @@ -1223,8 +1269,24 @@ def _handle_request_error(self, e) -> NoReturn:
msg = textwrap.fill(msg) + "\n\n(Network error: %s)" % (err,)
raise APIConnectionError(msg, should_retry=should_retry)

async def request_stream_async(self, method, url, headers, post_data=None):
raise NotImplementedError()
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[AsyncIterable[bytes], int, Mapping[str, str]]:
args, kwargs = self._get_request_args_kwargs(
method, url, headers, post_data
)
try:
response = await self._client.send(
request=self._client.build_request(*args, **kwargs),
stream=True,
)
except Exception as e:
self._handle_request_error(e)
content = response.aiter_bytes()
status_code = response.status_code
headers = response.headers

return content, status_code, headers

async def close(self):
await self._client.aclose()
Expand All @@ -1246,11 +1308,13 @@ def raise_async_client_import_error() -> Never:
)

async def request_async(
self, method, url, headers, post_data=None
self, method: str, url: str, headers: Mapping[str, str], post_data=None
) -> Tuple[bytes, int, Mapping[str, str]]:
self.raise_async_client_import_error()

async def request_stream_async(self, method, url, headers, post_data=None):
async def request_stream_async(
self, method: str, url: str, headers: Mapping[str, str], post_data=None
):
self.raise_async_client_import_error()

async def close_async(self):
Expand Down
22 changes: 10 additions & 12 deletions stripe/_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from stripe._nested_resource_class_methods import nested_resource_class_methods
from stripe._request_options import RequestOptions
from stripe._stripe_object import StripeObject
from stripe._stripe_response import StripeStreamResponseAsync
from stripe._updateable_api_resource import UpdateableAPIResource
from stripe._util import class_method_variant, sanitize_id
from typing import Any, ClassVar, Dict, List, Optional, cast, overload
Expand Down Expand Up @@ -4992,14 +4993,16 @@ async def _cls_pdf_async(
@staticmethod
async def pdf_async(
quote: str, **params: Unpack["Quote.PdfParams"]
) -> Any:
) -> StripeStreamResponseAsync:
"""
Download the PDF for a finalized quote
"""
...

@overload
async def pdf_async(self, **params: Unpack["Quote.PdfParams"]) -> Any:
async def pdf_async(
self, **params: Unpack["Quote.PdfParams"]
) -> StripeStreamResponseAsync:
"""
Download the PDF for a finalized quote
"""
Expand All @@ -5008,19 +5011,14 @@ async def pdf_async(self, **params: Unpack["Quote.PdfParams"]) -> Any:
@class_method_variant("_cls_pdf_async")
async def pdf_async( # pyright: ignore[reportGeneralTypeIssues]
self, **params: Unpack["Quote.PdfParams"]
) -> Any:
) -> StripeStreamResponseAsync:
"""
Download the PDF for a finalized quote
"""
return cast(
Any,
await self._request_stream_async(
"get",
"/v1/quotes/{quote}/pdf".format(
quote=sanitize_id(self.get("id"))
),
params=params,
),
return await self._request_stream_async(
"get",
"/v1/quotes/{quote}/pdf".format(quote=sanitize_id(self.get("id"))),
params=params,
)

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions stripe/_stripe_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
import stripe # noqa: IMP101
from stripe import _util

from stripe._stripe_response import StripeResponse, StripeStreamResponse
from stripe._stripe_response import (
StripeResponse,
StripeStreamResponse,
StripeStreamResponseAsync,
)
from stripe._encode import _encode_datetime # pyright: ignore
from stripe._request_options import extract_options_from_dict
from stripe._api_mode import ApiMode
Expand Down Expand Up @@ -471,7 +475,7 @@ async def _request_stream_async(
*,
base_address: BaseAddress = "api",
api_mode: ApiMode = "V1",
) -> StripeStreamResponse:
) -> StripeStreamResponseAsync:
if params is None:
params = self._retrieve_params

Expand Down
Loading

0 comments on commit c4bc5d2

Please sign in to comment.