1313from ._types import (
1414 NOT_GIVEN ,
1515 Omit ,
16+ Headers ,
1617 Timeout ,
1718 NotGiven ,
1819 Transport ,
2324from ._compat import cached_property
2425from ._version import __version__
2526from ._streaming import Stream as Stream , AsyncStream as AsyncStream
26- from ._exceptions import APIStatusError , GradientAIError
27+ from ._exceptions import APIStatusError
2728from ._base_client import (
2829 DEFAULT_MAX_RETRIES ,
2930 SyncAPIClient ,
5455
5556class GradientAI (SyncAPIClient ):
5657 # client options
57- api_key : str
58+ api_key : str | None
59+ inference_key : str | None
5860
5961 def __init__ (
6062 self ,
6163 * ,
6264 api_key : str | None = None ,
65+ inference_key : str | None = None ,
6366 base_url : str | httpx .URL | None = None ,
6467 timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
6568 max_retries : int = DEFAULT_MAX_RETRIES ,
@@ -81,16 +84,18 @@ def __init__(
8184 ) -> None :
8285 """Construct a new synchronous GradientAI client instance.
8386
84- This automatically infers the `api_key` argument from the `GRADIENTAI_API_KEY` environment variable if it is not provided.
87+ This automatically infers the following arguments from their corresponding environment variables if they are not provided:
88+ - `api_key` from `GRADIENTAI_API_KEY`
89+ - `inference_key` from `GRADIENTAI_API_KEY`
8590 """
8691 if api_key is None :
8792 api_key = os .environ .get ("GRADIENTAI_API_KEY" )
88- if api_key is None :
89- raise GradientAIError (
90- "The api_key client option must be set either by passing api_key to the client or by setting the GRADIENTAI_API_KEY environment variable"
91- )
9293 self .api_key = api_key
9394
95+ if inference_key is None :
96+ inference_key = os .environ .get ("GRADIENTAI_API_KEY" )
97+ self .inference_key = inference_key
98+
9499 if base_url is None :
95100 base_url = os .environ .get ("GRADIENT_AI_BASE_URL" )
96101 self ._base_url_overridden = base_url is not None
@@ -167,6 +172,8 @@ def qs(self) -> Querystring:
167172 @override
168173 def auth_headers (self ) -> dict [str , str ]:
169174 api_key = self .api_key
175+ if api_key is None :
176+ return {}
170177 return {"Authorization" : f"Bearer { api_key } " }
171178
172179 @property
@@ -178,10 +185,22 @@ def default_headers(self) -> dict[str, str | Omit]:
178185 ** self ._custom_headers ,
179186 }
180187
188+ @override
189+ def _validate_headers (self , headers : Headers , custom_headers : Headers ) -> None :
190+ if self .api_key and headers .get ("Authorization" ):
191+ return
192+ if isinstance (custom_headers .get ("Authorization" ), Omit ):
193+ return
194+
195+ raise TypeError (
196+ '"Could not resolve authentication method. Expected the api_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
197+ )
198+
181199 def copy (
182200 self ,
183201 * ,
184202 api_key : str | None = None ,
203+ inference_key : str | None = None ,
185204 base_url : str | httpx .URL | None = None ,
186205 timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
187206 http_client : httpx .Client | None = None ,
@@ -216,6 +235,7 @@ def copy(
216235 http_client = http_client or self ._client
217236 client = self .__class__ (
218237 api_key = api_key or self .api_key ,
238+ inference_key = inference_key or self .inference_key ,
219239 base_url = base_url or self .base_url ,
220240 timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
221241 http_client = http_client ,
@@ -267,12 +287,14 @@ def _make_status_error(
267287
268288class AsyncGradientAI (AsyncAPIClient ):
269289 # client options
270- api_key : str
290+ api_key : str | None
291+ inference_key : str | None
271292
272293 def __init__ (
273294 self ,
274295 * ,
275296 api_key : str | None = None ,
297+ inference_key : str | None = None ,
276298 base_url : str | httpx .URL | None = None ,
277299 timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
278300 max_retries : int = DEFAULT_MAX_RETRIES ,
@@ -294,16 +316,18 @@ def __init__(
294316 ) -> None :
295317 """Construct a new async AsyncGradientAI client instance.
296318
297- This automatically infers the `api_key` argument from the `GRADIENTAI_API_KEY` environment variable if it is not provided.
319+ This automatically infers the following arguments from their corresponding environment variables if they are not provided:
320+ - `api_key` from `GRADIENTAI_API_KEY`
321+ - `inference_key` from `GRADIENTAI_API_KEY`
298322 """
299323 if api_key is None :
300324 api_key = os .environ .get ("GRADIENTAI_API_KEY" )
301- if api_key is None :
302- raise GradientAIError (
303- "The api_key client option must be set either by passing api_key to the client or by setting the GRADIENTAI_API_KEY environment variable"
304- )
305325 self .api_key = api_key
306326
327+ if inference_key is None :
328+ inference_key = os .environ .get ("GRADIENTAI_API_KEY" )
329+ self .inference_key = inference_key
330+
307331 if base_url is None :
308332 base_url = os .environ .get ("GRADIENT_AI_BASE_URL" )
309333 self ._base_url_overridden = base_url is not None
@@ -380,6 +404,8 @@ def qs(self) -> Querystring:
380404 @override
381405 def auth_headers (self ) -> dict [str , str ]:
382406 api_key = self .api_key
407+ if api_key is None :
408+ return {}
383409 return {"Authorization" : f"Bearer { api_key } " }
384410
385411 @property
@@ -391,10 +417,22 @@ def default_headers(self) -> dict[str, str | Omit]:
391417 ** self ._custom_headers ,
392418 }
393419
420+ @override
421+ def _validate_headers (self , headers : Headers , custom_headers : Headers ) -> None :
422+ if self .api_key and headers .get ("Authorization" ):
423+ return
424+ if isinstance (custom_headers .get ("Authorization" ), Omit ):
425+ return
426+
427+ raise TypeError (
428+ '"Could not resolve authentication method. Expected the api_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
429+ )
430+
394431 def copy (
395432 self ,
396433 * ,
397434 api_key : str | None = None ,
435+ inference_key : str | None = None ,
398436 base_url : str | httpx .URL | None = None ,
399437 timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
400438 http_client : httpx .AsyncClient | None = None ,
@@ -429,6 +467,7 @@ def copy(
429467 http_client = http_client or self ._client
430468 client = self .__class__ (
431469 api_key = api_key or self .api_key ,
470+ inference_key = inference_key or self .inference_key ,
432471 base_url = base_url or self .base_url ,
433472 timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
434473 http_client = http_client ,
0 commit comments