55import logging
66import datetime
77import functools
8- from typing import TYPE_CHECKING , Any , Union , Generic , TypeVar , Callable , Iterator , AsyncIterator , cast
9- from typing_extensions import Awaitable , ParamSpec , get_args , override , deprecated , get_origin
8+ from typing import TYPE_CHECKING , Any , Union , Generic , TypeVar , Callable , Iterator , AsyncIterator , cast , overload
9+ from typing_extensions import Awaitable , ParamSpec , override , deprecated , get_origin
1010
1111import anyio
1212import httpx
13+ import pydantic
1314
1415from ._types import NoneType
1516from ._utils import is_given
1617from ._models import BaseModel , is_basemodel
1718from ._constants import RAW_RESPONSE_HEADER
19+ from ._streaming import Stream , AsyncStream , is_stream_class_type , extract_stream_chunk_type
1820from ._exceptions import APIResponseValidationError
1921
2022if TYPE_CHECKING :
2123 from ._models import FinalRequestOptions
22- from ._base_client import Stream , BaseClient , AsyncStream
24+ from ._base_client import BaseClient
2325
2426
2527P = ParamSpec ("P" )
2628R = TypeVar ("R" )
29+ _T = TypeVar ("_T" )
2730
2831log : logging .Logger = logging .getLogger (__name__ )
2932
@@ -43,7 +46,7 @@ class LegacyAPIResponse(Generic[R]):
4346
4447 _cast_to : type [R ]
4548 _client : BaseClient [Any , Any ]
46- _parsed : R | None
49+ _parsed_by_type : dict [ type [ Any ], Any ]
4750 _stream : bool
4851 _stream_cls : type [Stream [Any ]] | type [AsyncStream [Any ]] | None
4952 _options : FinalRequestOptions
@@ -62,27 +65,60 @@ def __init__(
6265 ) -> None :
6366 self ._cast_to = cast_to
6467 self ._client = client
65- self ._parsed = None
68+ self ._parsed_by_type = {}
6669 self ._stream = stream
6770 self ._stream_cls = stream_cls
6871 self ._options = options
6972 self .http_response = raw
7073
@overload
def parse(self, *, to: type[_T]) -> _T:
    ...

@overload
def parse(self) -> R:
    ...

def parse(self, *, to: type[_T] | None = None) -> R | _T:
    """Returns the rich python representation of this response's data.

    NOTE: For the async client: this will become a coroutine in the next major version.

    For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

    You can customise the type that the response is parsed into through
    the `to` argument, e.g.

    ```py
    from openai import BaseModel


    class MyModel(BaseModel):
        foo: str


    obj = response.parse(to=MyModel)
    print(obj.foo)
    ```

    We support parsing:
      - `BaseModel`
      - `dict`
      - `list`
      - `Union`
      - `str`
      - `httpx.Response`

    Results are cached per target type: the cache key is `to` when it is
    given, otherwise the `cast_to` type this response was created with.
    When a `post_parser` option is set, it is applied to the freshly parsed
    value before that value is cached and returned.

    Args:
        to: Optional type to parse the response into instead of the
            default `cast_to` type.

    Returns:
        The parsed (and, if configured, post-processed) response data.
    """
    cache_key = to if to is not None else self._cast_to

    # Use a membership test rather than `.get(...) is not None`: a parsed
    # value can legitimately be `None` (e.g. when the target type is
    # `NoneType`), and such a result must still count as a cache hit so the
    # response is not re-parsed and `post_parser` is not run a second time.
    if cache_key in self._parsed_by_type:
        return self._parsed_by_type[cache_key]  # type: ignore[no-any-return]

    parsed = self._parse(to=to)
    if is_given(self._options.post_parser):
        parsed = self._options.post_parser(parsed)

    self._parsed_by_type[cache_key] = parsed
    return parsed
87123
88124 @property
@@ -135,13 +171,29 @@ def elapsed(self) -> datetime.timedelta:
135171 """The time taken for the complete request/response cycle to complete."""
136172 return self .http_response .elapsed
137173
138- def _parse (self ) -> R :
174+ def _parse (self , * , to : type [ _T ] | None = None ) -> R | _T :
139175 if self ._stream :
176+ if to :
177+ if not is_stream_class_type (to ):
178+ raise TypeError (f"Expected custom parse type to be a subclass of { Stream } or { AsyncStream } " )
179+
180+ return cast (
181+ _T ,
182+ to (
183+ cast_to = extract_stream_chunk_type (
184+ to ,
185+ failure_message = "Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]" ,
186+ ),
187+ response = self .http_response ,
188+ client = cast (Any , self ._client ),
189+ ),
190+ )
191+
140192 if self ._stream_cls :
141193 return cast (
142194 R ,
143195 self ._stream_cls (
144- cast_to = _extract_stream_chunk_type (self ._stream_cls ),
196+ cast_to = extract_stream_chunk_type (self ._stream_cls ),
145197 response = self .http_response ,
146198 client = cast (Any , self ._client ),
147199 ),
@@ -160,7 +212,7 @@ def _parse(self) -> R:
160212 ),
161213 )
162214
163- cast_to = self ._cast_to
215+ cast_to = to if to is not None else self ._cast_to
164216 if cast_to is NoneType :
165217 return cast (R , None )
166218
@@ -186,14 +238,9 @@ def _parse(self) -> R:
186238 raise ValueError (f"Subclasses of httpx.Response cannot be passed to `cast_to`" )
187239 return cast (R , response )
188240
189- # The check here is necessary as we are subverting the the type system
190- # with casts as the relationship between TypeVars and Types are very strict
191- # which means we must return *exactly* what was input or transform it in a
192- # way that retains the TypeVar state. As we cannot do that in this function
193- # then we have to resort to using `cast`. At the time of writing, we know this
194- # to be safe as we have handled all the types that could be bound to the
195- # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
196- # this function would become unsafe but a type checker would not report an error.
241+ if inspect .isclass (origin ) and not issubclass (origin , BaseModel ) and issubclass (origin , pydantic .BaseModel ):
242+ raise TypeError ("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`" )
243+
197244 if (
198245 cast_to is not object
199246 and not origin is list
@@ -202,12 +249,12 @@ def _parse(self) -> R:
202249 and not issubclass (origin , BaseModel )
203250 ):
204251 raise RuntimeError (
205- f"Invalid state , expected { cast_to } to be a subclass type of { BaseModel } , { dict } , { list } or { Union } ."
252+ f"Unsupported type , expected { cast_to } to be a subclass of { BaseModel } , { dict } , { list } , { Union } , { NoneType } , { str } or { httpx . Response } ."
206253 )
207254
208255 # split is required to handle cases where additional information is included
209256 # in the response, e.g. application/json; charset=utf-8
210- content_type , * _ = response .headers .get ("content-type" ).split (";" )
257+ content_type , * _ = response .headers .get ("content-type" , "*" ).split (";" )
211258 if content_type != "application/json" :
212259 if is_basemodel (cast_to ):
213260 try :
@@ -253,15 +300,6 @@ def __init__(self) -> None:
253300 )
254301
255302
256- def _extract_stream_chunk_type (stream_cls : type ) -> type :
257- args = get_args (stream_cls )
258- if not args :
259- raise TypeError (
260- f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received { stream_cls } " ,
261- )
262- return cast (type , args [0 ])
263-
264-
265303def to_raw_response_wrapper (func : Callable [P , R ]) -> Callable [P , LegacyAPIResponse [R ]]:
266304 """Higher order function that takes one of our bound API methods and wraps it
267305 to support returning the raw `APIResponse` object directly.
0 commit comments