Skip to content

Commit

Permalink
fix(client/async): avoid blocking io call for platform headers (#1488)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] authored and stainless-bot committed Jun 19, 2024
1 parent 6aa2a80 commit ae64c05
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
17 changes: 13 additions & 4 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
RequestOptions,
ModelBuilderProtocol,
)
from ._utils import is_dict, is_list, is_given, lru_cache, is_mapping
from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
from ._compat import model_copy, model_dump
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
from ._response import (
Expand Down Expand Up @@ -359,6 +359,7 @@ def __init__(
self._custom_query = custom_query or {}
self._strict_response_validation = _strict_response_validation
self._idempotency_header = None
self._platform: Platform | None = None

if max_retries is None: # pyright: ignore[reportUnnecessaryComparison]
raise TypeError(
Expand Down Expand Up @@ -623,7 +624,10 @@ def base_url(self, url: URL | str) -> None:
self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))

def platform_headers(self) -> Dict[str, str]:
return platform_headers(self._version)
# the actual implementation is in a separate `lru_cache` decorated
# function because adding `lru_cache` to methods will leak memory
# https://github.com/python/cpython/issues/88476
return platform_headers(self._version, platform=self._platform)

def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None:
"""Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified.
Expand Down Expand Up @@ -1513,6 +1517,11 @@ async def _request(
stream_cls: type[_AsyncStreamT] | None,
remaining_retries: int | None,
) -> ResponseT | _AsyncStreamT:
if self._platform is None:
# `get_platform` can make blocking IO calls so we
# execute it earlier while we are in an async context
self._platform = await asyncify(get_platform)()

cast_to = self._maybe_override_cast_to(cast_to, options)
await self._prepare_options(options)

Expand Down Expand Up @@ -1949,11 +1958,11 @@ def get_platform() -> Platform:


@lru_cache(maxsize=None)
def platform_headers(version: str) -> Dict[str, str]:
def platform_headers(version: str, *, platform: Platform | None) -> Dict[str, str]:
return {
"X-Stainless-Lang": "python",
"X-Stainless-Package-Version": version,
"X-Stainless-OS": str(get_platform()),
"X-Stainless-OS": str(platform or get_platform()),
"X-Stainless-Arch": str(get_architecture()),
"X-Stainless-Runtime": get_python_runtime(),
"X-Stainless-Runtime-Version": get_python_version(),
Expand Down
1 change: 1 addition & 0 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)
from ._reflection import function_has_argument as function_has_argument
8 changes: 8 additions & 0 deletions src/openai/_utils/_reflection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import inspect
from typing import Any, Callable


def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
"""Returns whether or not the given function has a specific parameter"""
sig = inspect.signature(func)
return arg_name in sig.parameters
19 changes: 18 additions & 1 deletion src/openai/_utils/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import anyio
import anyio.to_thread

from ._reflection import function_has_argument

T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")

Expand Down Expand Up @@ -59,6 +61,21 @@ def do_work(arg1, arg2, kwarg1="", kwarg2="") -> str:

async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
partial_f = functools.partial(function, *args, **kwargs)
return await anyio.to_thread.run_sync(partial_f, cancellable=cancellable, limiter=limiter)

# In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
# `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
# surfacing deprecation warnings.
if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
return await anyio.to_thread.run_sync(
partial_f,
abandon_on_cancel=cancellable,
limiter=limiter,
)

return await anyio.to_thread.run_sync(
partial_f,
cancellable=cancellable,
limiter=limiter,
)

return wrapper

0 comments on commit ae64c05

Please sign in to comment.