Skip to content

Commit

Permalink
chore(internal): minor options / compat functions updates (#1549)
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie authored and stainless-app[bot] committed Jul 16, 2024
1 parent f14f859 commit 83ebf66
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
12 changes: 6 additions & 6 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,9 +880,9 @@ def __exit__(
def _prepare_options(
self,
options: FinalRequestOptions, # noqa: ARG002
) -> None:
) -> FinalRequestOptions:
"""Hook for mutating the given options"""
return None
return options

def _prepare_request(
self,
Expand Down Expand Up @@ -962,7 +962,7 @@ def _request(
input_options = model_copy(options)

cast_to = self._maybe_override_cast_to(cast_to, options)
self._prepare_options(options)
options = self._prepare_options(options)

retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
Expand Down Expand Up @@ -1457,9 +1457,9 @@ async def __aexit__(
async def _prepare_options(
self,
options: FinalRequestOptions, # noqa: ARG002
) -> None:
) -> FinalRequestOptions:
"""Hook for mutating the given options"""
return None
return options

async def _prepare_request(
self,
Expand Down Expand Up @@ -1544,7 +1544,7 @@ async def _request(
input_options = model_copy(options)

cast_to = self._maybe_override_cast_to(cast_to, options)
await self._prepare_options(options)
options = await self._prepare_options(options)

retries = self._remaining_retries(remaining_retries, options)
request = self._build_request(options)
Expand Down
6 changes: 3 additions & 3 deletions src/openai/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
return model.__fields__ # type: ignore


def model_copy(model: _ModelT) -> _ModelT:
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
if PYDANTIC_V2:
return model.model_copy()
return model.copy() # type: ignore
return model.model_copy(deep=deep)
return model.copy(deep=deep) # type: ignore


def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
Expand Down
13 changes: 9 additions & 4 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .._types import NOT_GIVEN, Omit, Timeout, NotGiven
from .._utils import is_given, is_mapping
from .._client import OpenAI, AsyncOpenAI
from .._compat import model_copy
from .._models import FinalRequestOptions
from .._streaming import Stream, AsyncStream
from .._exceptions import OpenAIError
Expand Down Expand Up @@ -281,8 +282,10 @@ def _get_azure_ad_token(self) -> str | None:
return None

@override
def _prepare_options(self, options: FinalRequestOptions) -> None:
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}

options = model_copy(options)
options.headers = headers

azure_ad_token = self._get_azure_ad_token()
Expand All @@ -296,7 +299,7 @@ def _prepare_options(self, options: FinalRequestOptions) -> None:
# should never be hit
raise ValueError("Unable to handle auth")

return super()._prepare_options(options)
return options


class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
Expand Down Expand Up @@ -524,8 +527,10 @@ async def _get_azure_ad_token(self) -> str | None:
return None

@override
async def _prepare_options(self, options: FinalRequestOptions) -> None:
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}

options = model_copy(options)
options.headers = headers

azure_ad_token = await self._get_azure_ad_token()
Expand All @@ -539,4 +544,4 @@ async def _prepare_options(self, options: FinalRequestOptions) -> None:
# should never be hit
raise ValueError("Unable to handle auth")

return await super()._prepare_options(options)
return options

0 comments on commit 83ebf66

Please sign in to comment.