Skip to content

Commit 108d7cb

Browse files
authored
fix: add proto to default inference url (#52)
1 parent 76aafd7 commit 108d7cb

File tree

1 file changed

+53
-17
lines changed

1 file changed

+53
-17
lines changed

src/gradient/_client.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ def __init__(
133133
self._agent_endpoint = agent_endpoint
134134

135135
if inference_endpoint is None:
136-
inference_endpoint = os.environ.get("GRADIENT_INFERENCE_ENDPOINT") or "inference.do-ai.run"
136+
inference_endpoint = (
137+
os.environ.get("GRADIENT_INFERENCE_ENDPOINT")
138+
or "https://inference.do-ai.run"
139+
)
137140
self.inference_endpoint = inference_endpoint
138141

139142
if base_url is None:
@@ -250,7 +253,9 @@ def default_headers(self) -> dict[str, str | Omit]:
250253

251254
@override
252255
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
253-
if (self.access_token or self.agent_access_key or self.model_access_key) and headers.get("Authorization"):
256+
if (
257+
self.access_token or self.agent_access_key or self.model_access_key
258+
) and headers.get("Authorization"):
254259
return
255260
if isinstance(custom_headers.get("Authorization"), Omit):
256261
return
@@ -283,10 +288,14 @@ def copy(
283288
Create a new client instance re-using the same options given to the current client with optional overriding.
284289
"""
285290
if default_headers is not None and set_default_headers is not None:
286-
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
291+
raise ValueError(
292+
"The `default_headers` and `set_default_headers` arguments are mutually exclusive"
293+
)
287294

288295
if default_query is not None and set_default_query is not None:
289-
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
296+
raise ValueError(
297+
"The `default_query` and `set_default_query` arguments are mutually exclusive"
298+
)
290299

291300
headers = self._custom_headers
292301
if default_headers is not None:
@@ -336,10 +345,14 @@ def _make_status_error(
336345
return _exceptions.BadRequestError(err_msg, response=response, body=body)
337346

338347
if response.status_code == 401:
339-
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
348+
return _exceptions.AuthenticationError(
349+
err_msg, response=response, body=body
350+
)
340351

341352
if response.status_code == 403:
342-
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
353+
return _exceptions.PermissionDeniedError(
354+
err_msg, response=response, body=body
355+
)
343356

344357
if response.status_code == 404:
345358
return _exceptions.NotFoundError(err_msg, response=response, body=body)
@@ -348,13 +361,17 @@ def _make_status_error(
348361
return _exceptions.ConflictError(err_msg, response=response, body=body)
349362

350363
if response.status_code == 422:
351-
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
364+
return _exceptions.UnprocessableEntityError(
365+
err_msg, response=response, body=body
366+
)
352367

353368
if response.status_code == 429:
354369
return _exceptions.RateLimitError(err_msg, response=response, body=body)
355370

356371
if response.status_code >= 500:
357-
return _exceptions.InternalServerError(err_msg, response=response, body=body)
372+
return _exceptions.InternalServerError(
373+
err_msg, response=response, body=body
374+
)
358375
return APIStatusError(err_msg, response=response, body=body)
359376

360377

@@ -422,7 +439,10 @@ def __init__(
422439
self._agent_endpoint = agent_endpoint
423440

424441
if inference_endpoint is None:
425-
inference_endpoint = os.environ.get("GRADIENT_INFERENCE_ENDPOINT") or "inference.do-ai.run"
442+
inference_endpoint = (
443+
os.environ.get("GRADIENT_INFERENCE_ENDPOINT")
444+
or "https://inference.do-ai.run"
445+
)
426446
self.inference_endpoint = inference_endpoint
427447

428448
if base_url is None:
@@ -539,7 +559,9 @@ def default_headers(self) -> dict[str, str | Omit]:
539559

540560
@override
541561
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
542-
if (self.access_token or self.agent_access_key or self.model_access_key) and headers.get("Authorization"):
562+
if (
563+
self.access_token or self.agent_access_key or self.model_access_key
564+
) and headers.get("Authorization"):
543565
return
544566
if isinstance(custom_headers.get("Authorization"), Omit):
545567
return
@@ -572,10 +594,14 @@ def copy(
572594
Create a new client instance re-using the same options given to the current client with optional overriding.
573595
"""
574596
if default_headers is not None and set_default_headers is not None:
575-
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
597+
raise ValueError(
598+
"The `default_headers` and `set_default_headers` arguments are mutually exclusive"
599+
)
576600

577601
if default_query is not None and set_default_query is not None:
578-
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
602+
raise ValueError(
603+
"The `default_query` and `set_default_query` arguments are mutually exclusive"
604+
)
579605

580606
headers = self._custom_headers
581607
if default_headers is not None:
@@ -625,10 +651,14 @@ def _make_status_error(
625651
return _exceptions.BadRequestError(err_msg, response=response, body=body)
626652

627653
if response.status_code == 401:
628-
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
654+
return _exceptions.AuthenticationError(
655+
err_msg, response=response, body=body
656+
)
629657

630658
if response.status_code == 403:
631-
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
659+
return _exceptions.PermissionDeniedError(
660+
err_msg, response=response, body=body
661+
)
632662

633663
if response.status_code == 404:
634664
return _exceptions.NotFoundError(err_msg, response=response, body=body)
@@ -637,13 +667,17 @@ def _make_status_error(
637667
return _exceptions.ConflictError(err_msg, response=response, body=body)
638668

639669
if response.status_code == 422:
640-
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
670+
return _exceptions.UnprocessableEntityError(
671+
err_msg, response=response, body=body
672+
)
641673

642674
if response.status_code == 429:
643675
return _exceptions.RateLimitError(err_msg, response=response, body=body)
644676

645677
if response.status_code >= 500:
646-
return _exceptions.InternalServerError(err_msg, response=response, body=body)
678+
return _exceptions.InternalServerError(
679+
err_msg, response=response, body=body
680+
)
647681
return APIStatusError(err_msg, response=response, body=body)
648682

649683

@@ -862,7 +896,9 @@ def knowledge_bases(
862896
AsyncKnowledgeBasesResourceWithStreamingResponse,
863897
)
864898

865-
return AsyncKnowledgeBasesResourceWithStreamingResponse(self._client.knowledge_bases)
899+
return AsyncKnowledgeBasesResourceWithStreamingResponse(
900+
self._client.knowledge_bases
901+
)
866902

867903
@cached_property
868904
def models(self) -> models.AsyncModelsResourceWithStreamingResponse:

0 commit comments

Comments
 (0)