Skip to content

Commit d27046d

Browse files
committed
fix: fix validation for inference_key and agent_key auth
1 parent 31809e8 commit d27046d

File tree

4 files changed

+26
-10
lines changed

4 files changed

+26
-10
lines changed

src/do_gradientai/_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,13 @@ def default_headers(self) -> dict[str, str | Omit]:
272272

273273
@override
274274
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
275-
if self.api_key and headers.get("Authorization"):
275+
if (self.api_key or self.agent_key or self.inference_key) and headers.get("Authorization"):
276276
return
277277
if isinstance(custom_headers.get("Authorization"), Omit):
278278
return
279279

280280
raise TypeError(
281-
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
281+
'"Could not resolve authentication method. Expected api_key, agent_key, or inference_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
282282
)
283283

284284
def copy(
@@ -569,13 +569,13 @@ def default_headers(self) -> dict[str, str | Omit]:
569569

570570
@override
571571
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
572-
if self.api_key and headers.get("Authorization"):
572+
if (self.api_key or self.agent_key or self.inference_key) and headers.get("Authorization"):
573573
return
574574
if isinstance(custom_headers.get("Authorization"), Omit):
575575
return
576576

577577
raise TypeError(
578-
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
578+
'"Could not resolve authentication method. Expected api_key, agent_key, or inference_key to be set. Or for the `Authorization` headers to be explicitly omitted"'
579579
)
580580

581581
def copy(

src/do_gradientai/resources/agents/chat/completions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,14 @@ def create(
470470
extra_body: Body | None = None,
471471
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
472472
) -> CompletionCreateResponse | Stream[ChatCompletionChunk]:
473+
# This method requires an agent_key to be set via client argument or environment variable
474+
if not self._client.agent_key:
475+
raise TypeError(
476+
"Could not resolve authentication method. Expected agent_key to be set for chat completions."
477+
)
478+
headers = extra_headers or {}
479+
headers = {"Authorization": f"Bearer {self._client.agent_key}", **headers}
480+
473481
return self._post(
474482
"/chat/completions?agent=true"
475483
if self._client._base_url_overridden
@@ -501,7 +509,7 @@ def create(
501509
else completion_create_params.CompletionCreateParamsNonStreaming,
502510
),
503511
options=make_request_options(
504-
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
512+
extra_headers=headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
505513
),
506514
cast_to=CompletionCreateResponse,
507515
stream=stream or False,
@@ -953,6 +961,14 @@ async def create(
953961
extra_body: Body | None = None,
954962
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
955963
) -> CompletionCreateResponse | AsyncStream[ChatCompletionChunk]:
964+
# This method requires an agent_key to be set via client argument or environment variable
965+
if not self._client.agent_key:
966+
raise TypeError(
967+
"Could not resolve authentication method. Expected agent_key to be set for chat completions."
968+
)
969+
headers = extra_headers or {}
970+
headers = {"Authorization": f"Bearer {self._client.agent_key}", **headers}
971+
956972
return await self._post(
957973
"/chat/completions?agent=true"
958974
if self._client._base_url_overridden
@@ -984,7 +1000,7 @@ async def create(
9841000
else completion_create_params.CompletionCreateParamsNonStreaming,
9851001
),
9861002
options=make_request_options(
987-
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
1003+
extra_headers=headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
9881004
),
9891005
cast_to=CompletionCreateResponse,
9901006
stream=stream or False,

src/do_gradientai/resources/chat/completions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def create(
464464
# This method requires an inference_key to be set via client argument or environment variable
465465
if not self._client.inference_key:
466466
raise TypeError(
467-
"Could not resolve authentication method. Expected the inference_key to be set for chat completions."
467+
"Could not resolve authentication method. Expected inference_key to be set for chat completions."
468468
)
469469
headers = extra_headers or {}
470470
headers = {"Authorization": f"Bearer {self._client.inference_key}", **headers}
@@ -946,7 +946,7 @@ async def create(
946946
# This method requires an inference_key to be set via client argument or environment variable
947947
if not hasattr(self._client, "inference_key") or not self._client.inference_key:
948948
raise TypeError(
949-
"Could not resolve authentication method. Expected the inference_key to be set for chat completions."
949+
"Could not resolve authentication method. Expected inference_key to be set for chat completions."
950950
)
951951
headers = extra_headers or {}
952952
headers = {"Authorization": f"Bearer {self._client.inference_key}", **headers}

tests/test_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def test_validate_headers(self) -> None:
414414

415415
with pytest.raises(
416416
TypeError,
417-
match="Could not resolve authentication method. Expected the api_key to be set. Or for the `Authorization` headers to be explicitly omitted",
417+
match="Could not resolve authentication method. Expected api_key, agent_key, or inference_key to be set. Or for the `Authorization` headers to be explicitly omitted",
418418
):
419419
client2._build_request(FinalRequestOptions(method="get", url="/foo"))
420420

@@ -1416,7 +1416,7 @@ def test_validate_headers(self) -> None:
14161416

14171417
with pytest.raises(
14181418
TypeError,
1419-
match="Could not resolve authentication method. Expected the api_key to be set. Or for the `Authorization` headers to be explicitly omitted",
1419+
match="Could not resolve authentication method. Expected api_key, agent_key, or inference_key to be set. Or for the `Authorization` headers to be explicitly omitted",
14201420
):
14211421
client2._build_request(FinalRequestOptions(method="get", url="/foo"))
14221422

0 commit comments

Comments
 (0)