Skip to content

Commit 5d38e2e

Browse files
committed
feat: use inference key for chat.completions.create()
1 parent e92c54b commit 5d38e2e

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

src/do_gradientai/resources/chat/completions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,15 @@ def create(
153153
154154
timeout: Override the client-level default timeout for this request, in seconds
155155
"""
156+
157+
# This method requires an inference_key to be set via client argument or environment variable
158+
if not hasattr(self._client, "inference_key") or not self._client.inference_key:
159+
raise TypeError(
160+
"Could not resolve authentication method. Expected the inference_key to be set for chat completions."
161+
)
162+
headers = extra_headers or {}
163+
headers = {"Authorization": f"Bearer {self._client.inference_key}", **headers}
164+
156165
return self._post(
157166
"/chat/completions"
158167
if self._client._base_url_overridden
@@ -180,7 +189,7 @@ def create(
180189
completion_create_params.CompletionCreateParams,
181190
),
182191
options=make_request_options(
183-
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
192+
extra_headers=headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
184193
),
185194
cast_to=CompletionCreateResponse,
186195
)
@@ -316,6 +325,15 @@ async def create(
316325
317326
timeout: Override the client-level default timeout for this request, in seconds
318327
"""
328+
329+
# This method requires an inference_key to be set via client argument or environment variable
330+
if not hasattr(self._client, "inference_key") or not self._client.inference_key:
331+
raise TypeError(
332+
"Could not resolve authentication method. Expected the inference_key to be set for chat completions."
333+
)
334+
headers = extra_headers or {}
335+
headers = {"Authorization": f"Bearer {self._client.inference_key}", **headers}
336+
319337
return await self._post(
320338
"/chat/completions"
321339
if self._client._base_url_overridden
@@ -343,7 +361,7 @@ async def create(
343361
completion_create_params.CompletionCreateParams,
344362
),
345363
options=make_request_options(
346-
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
364+
extra_headers=headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
347365
),
348366
cast_to=CompletionCreateResponse,
349367
)

0 commit comments

Comments
 (0)