Skip to content

Commit

Permalink
feat(client): support parsing custom response types (#1111)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Jan 29, 2024
1 parent fa63e60 commit da00fc3
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 79 deletions.
2 changes: 2 additions & 0 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._types import NoneType, Transport, ProxiesTypes
from ._utils import file_from_path
from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
from ._models import BaseModel
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._exceptions import (
Expand Down Expand Up @@ -59,6 +60,7 @@
"OpenAI",
"AsyncOpenAI",
"file_from_path",
"BaseModel",
]

from .lib import azure as _azure
Expand Down
102 changes: 70 additions & 32 deletions src/openai/_legacy_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,28 @@
import logging
import datetime
import functools
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast
from typing_extensions import Awaitable, ParamSpec, get_args, override, deprecated, get_origin
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload
from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin

import anyio
import httpx
import pydantic

from ._types import NoneType
from ._utils import is_given
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
from ._exceptions import APIResponseValidationError

if TYPE_CHECKING:
from ._models import FinalRequestOptions
from ._base_client import Stream, BaseClient, AsyncStream
from ._base_client import BaseClient


P = ParamSpec("P")
R = TypeVar("R")
_T = TypeVar("_T")

log: logging.Logger = logging.getLogger(__name__)

Expand All @@ -43,7 +46,7 @@ class LegacyAPIResponse(Generic[R]):

_cast_to: type[R]
_client: BaseClient[Any, Any]
_parsed: R | None
_parsed_by_type: dict[type[Any], Any]
_stream: bool
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
_options: FinalRequestOptions
Expand All @@ -62,27 +65,60 @@ def __init__(
) -> None:
self._cast_to = cast_to
self._client = client
self._parsed = None
self._parsed_by_type = {}
self._stream = stream
self._stream_cls = stream_cls
self._options = options
self.http_response = raw

@overload
def parse(self, *, to: type[_T]) -> _T:
...

@overload
def parse(self) -> R:
...

def parse(self, *, to: type[_T] | None = None) -> R | _T:
"""Returns the rich python representation of this response's data.
NOTE: For the async client: this will become a coroutine in the next major version.
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
NOTE: For the async client: this will become a coroutine in the next major version.
You can customise the type that the response is parsed into through
the `to` argument, e.g.
```py
from openai import BaseModel
class MyModel(BaseModel):
foo: str
obj = response.parse(to=MyModel)
print(obj.foo)
```
We support parsing:
- `BaseModel`
- `dict`
- `list`
- `Union`
- `str`
- `httpx.Response`
"""
if self._parsed is not None:
return self._parsed
cache_key = to if to is not None else self._cast_to
cached = self._parsed_by_type.get(cache_key)
if cached is not None:
return cached # type: ignore[no-any-return]

parsed = self._parse()
parsed = self._parse(to=to)
if is_given(self._options.post_parser):
parsed = self._options.post_parser(parsed)

self._parsed = parsed
self._parsed_by_type[cache_key] = parsed
return parsed

@property
Expand Down Expand Up @@ -135,13 +171,29 @@ def elapsed(self) -> datetime.timedelta:
"""The time taken for the complete request/response cycle to complete."""
return self.http_response.elapsed

def _parse(self) -> R:
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
if self._stream:
if to:
if not is_stream_class_type(to):
raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")

return cast(
_T,
to(
cast_to=extract_stream_chunk_type(
to,
failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
),
response=self.http_response,
client=cast(Any, self._client),
),
)

if self._stream_cls:
return cast(
R,
self._stream_cls(
cast_to=_extract_stream_chunk_type(self._stream_cls),
cast_to=extract_stream_chunk_type(self._stream_cls),
response=self.http_response,
client=cast(Any, self._client),
),
Expand All @@ -160,7 +212,7 @@ def _parse(self) -> R:
),
)

cast_to = self._cast_to
cast_to = to if to is not None else self._cast_to
if cast_to is NoneType:
return cast(R, None)

Expand All @@ -186,14 +238,9 @@ def _parse(self) -> R:
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(R, response)

# The check here is necessary as we are subverting the the type system
# with casts as the relationship between TypeVars and Types are very strict
# which means we must return *exactly* what was input or transform it in a
# way that retains the TypeVar state. As we cannot do that in this function
# then we have to resort to using `cast`. At the time of writing, we know this
# to be safe as we have handled all the types that could be bound to the
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
# this function would become unsafe but a type checker would not report an error.
if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")

if (
cast_to is not object
and not origin is list
Expand All @@ -202,12 +249,12 @@ def _parse(self) -> R:
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
)

# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type").split(";")
content_type, *_ = response.headers.get("content-type", "*").split(";")
if content_type != "application/json":
if is_basemodel(cast_to):
try:
Expand Down Expand Up @@ -253,15 +300,6 @@ def __init__(self) -> None:
)


def _extract_stream_chunk_type(stream_cls: type) -> type:
args = get_args(stream_cls)
if not args:
raise TypeError(
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
)
return cast(type, args[0])


def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
"""Higher order function that takes one of our bound API methods and wraps it
to support returning the raw `APIResponse` object directly.
Expand Down
Loading

0 comments on commit da00fc3

Please sign in to comment.