diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 05a89405..ec9a4813 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,13 +36,13 @@ jobs: run: ./scripts/lint build: - if: github.repository == 'stainless-sdks/gradient-python' && (github.event_name == 'push' || github.event.pull_request.head.repo.fork) + if: github.event_name == 'push' || github.event.pull_request.head.repo.fork timeout-minutes: 10 name: build permissions: contents: read id-token: write - runs-on: depot-ubuntu-24.04 + runs-on: ${{ github.repository == 'stainless-sdks/gradient-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} steps: - uses: actions/checkout@v4 @@ -61,12 +61,14 @@ jobs: run: rye build - name: Get GitHub OIDC Token + if: github.repository == 'stainless-sdks/gradient-python' id: github-oidc uses: actions/github-script@v6 with: script: core.setOutput('github_token', await core.getIDToken()); - name: Upload tarball + if: github.repository == 'stainless-sdks/gradient-python' env: URL: https://pkg.stainless.com/s AUTH: ${{ steps.github-oidc.outputs.github_token }} diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 2ce88448..9dcd5cc8 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "3.0.0-beta.4" + ".": "3.0.0-beta.5" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 7b81dd11..4a621094 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 170 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/digitalocean%2Fgradient-9aca3802735e1375125412aa28ac36bf2175144b8218610a73d2e7f775694dff.yml +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/digitalocean%2Fgradient-621c3ebf5011c5ca508f78fccbea17de4ca6b35bfe99578c1ae2265021578d6f.yml openapi_spec_hash: e29d14e3e4679fcf22b3e760e49931b1 -config_hash: 99e3cd5dde0beb796f4547410869f726 +config_hash: 6c8d569b60ae6536708a165b72ff838f diff --git a/CHANGELOG.md b/CHANGELOG.md index 351216f4..85fdc0d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # Changelog +## 3.0.0-beta.5 (2025-09-08) + +Full Changelog: [v3.0.0-beta.4...v3.0.0-beta.5](https://github.com/digitalocean/gradient-python/compare/v3.0.0-beta.4...v3.0.0-beta.5) + +### Features + +* **api:** manual updates ([044a233](https://github.com/digitalocean/gradient-python/commit/044a2339f9ae89facbed403d8240d1e4cf3e9c1f)) +* **api:** manual updates ([0e8fd1b](https://github.com/digitalocean/gradient-python/commit/0e8fd1b364751ec933cadf02be693afa63a67029)) + + +### Bug Fixes + +* avoid newer type syntax ([3d5c35c](https://github.com/digitalocean/gradient-python/commit/3d5c35ca11b4c7344308f7fbd7cd98ec44dd65a0)) + + +### Chores + +* **internal:** add Sequence related utils ([2997cfc](https://github.com/digitalocean/gradient-python/commit/2997cfc25bf46b4cc9faf9f0f22cb4680cadca8b)) +* **internal:** change ci workflow machines ([5f41b3d](https://github.com/digitalocean/gradient-python/commit/5f41b3d956bf1ae25f90b862d5057c16b06e78a3)) +* **internal:** update pyright exclude list ([2a0d1a2](https://github.com/digitalocean/gradient-python/commit/2a0d1a2b174990d6b081ff764b13949b4dfa107f)) +* update github action ([369c5d9](https://github.com/digitalocean/gradient-python/commit/369c5d982cfadfaaaeda9481b2c9249e3f87423d)) + ## 3.0.0-beta.4 (2025-08-12) Full Changelog: [v3.0.0-beta.3...v3.0.0-beta.4](https://github.com/digitalocean/gradient-python/compare/v3.0.0-beta.3...v3.0.0-beta.4) diff --git a/pyproject.toml b/pyproject.toml index 3d37f719..bde954ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "gradient" -version = "3.0.0-beta.4" +version = "3.0.0-beta.5" description = "The official Python library for the Gradient API" dynamic = ["readme"] license = "Apache-2.0" @@ -131,7 +131,12 @@ filterwarnings = ["error"] typeCheckingMode = "strict" pythonVersion = "3.8" -exclude = ["_dev", ".venv", ".nox"] +exclude = [ + "_dev", + ".venv", + ".nox", + ".git", +] reportImplicitOverride = true reportOverlappingOverload = false diff --git a/src/gradient/_client.py b/src/gradient/_client.py index f5866900..74b57d84 100644 --- a/src/gradient/_client.py +++ b/src/gradient/_client.py @@ -113,6 +113,8 @@ def __init__( - `access_token` from `DIGITALOCEAN_ACCESS_TOKEN` - `model_access_key` from `GRADIENT_MODEL_ACCESS_KEY` - `agent_access_key` from `GRADIENT_AGENT_ACCESS_KEY` + - `agent_endpoint` from `GRADIENT_AGENT_ENDPOINT` + - `inference_endpoint` from `GRADIENT_INFERENCE_ENDPOINT` """ if access_token is None: if api_key is not None: @@ -149,10 +151,7 @@ def __init__( self._agent_endpoint = agent_endpoint if inference_endpoint is None: - inference_endpoint = os.environ.get("GRADIENT_INFERENCE_ENDPOINT") - if inference_endpoint is None: - inference_endpoint = "https://inference.do-ai.run" - + inference_endpoint = os.environ.get("GRADIENT_INFERENCE_ENDPOINT") or "inference.do-ai.run" self.inference_endpoint = inference_endpoint if base_url is None: @@ -267,9 +266,7 @@ def default_headers(self) -> dict[str, str | Omit]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: - if ( - self.access_token or self.agent_access_key or self.model_access_key - ) and headers.get("Authorization"): + if (self.access_token or self.agent_access_key or self.model_access_key) and headers.get("Authorization"): return if isinstance(custom_headers.get("Authorization"), Omit): return @@ -303,14 +300,10 @@ def copy( Create a new client instance re-using the same options given to the current client with optional overriding. """ if default_headers is not None and set_default_headers is not None: - raise ValueError( - "The `default_headers` and `set_default_headers` arguments are mutually exclusive" - ) + raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") if default_query is not None and set_default_query is not None: - raise ValueError( - "The `default_query` and `set_default_query` arguments are mutually exclusive" - ) + raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") headers = self._custom_headers if default_headers is not None: @@ -358,14 +351,10 @@ def _make_status_error( return _exceptions.BadRequestError(err_msg, response=response, body=body) if response.status_code == 401: - return _exceptions.AuthenticationError( - err_msg, response=response, body=body - ) + return _exceptions.AuthenticationError(err_msg, response=response, body=body) if response.status_code == 403: - return _exceptions.PermissionDeniedError( - err_msg, response=response, body=body - ) + return _exceptions.PermissionDeniedError(err_msg, response=response, body=body) if response.status_code == 404: return _exceptions.NotFoundError(err_msg, response=response, body=body) @@ -374,17 +363,13 @@ def _make_status_error( return _exceptions.ConflictError(err_msg, response=response, body=body) if response.status_code == 422: - return _exceptions.UnprocessableEntityError( - err_msg, response=response, body=body - ) + return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body) if response.status_code == 429: return _exceptions.RateLimitError(err_msg, response=response, body=body) if response.status_code >= 500: - return _exceptions.InternalServerError( - err_msg, response=response, body=body - ) + return _exceptions.InternalServerError(err_msg, response=response, body=body) return APIStatusError(err_msg, response=response, body=body) @@ -432,6 +417,8 @@ def __init__( - `access_token` from `DIGITALOCEAN_ACCESS_TOKEN` - `model_access_key` from `GRADIENT_MODEL_ACCESS_KEY` - `agent_access_key` from `GRADIENT_AGENT_ACCESS_KEY` + - `agent_endpoint` from `GRADIENT_AGENT_ENDPOINT` + - `inference_endpoint` from `GRADIENT_INFERENCE_ENDPOINT` """ if access_token is None: if api_key is not None: @@ -463,8 +450,12 @@ def __init__( agent_access_key = os.environ.get("GRADIENT_AGENT_KEY") self.agent_access_key = agent_access_key + if agent_endpoint is None: + agent_endpoint = os.environ.get("GRADIENT_AGENT_ENDPOINT") self._agent_endpoint = agent_endpoint + if inference_endpoint is None: + inference_endpoint = os.environ.get("GRADIENT_INFERENCE_ENDPOINT") or "inference.do-ai.run" self.inference_endpoint = inference_endpoint if base_url is None: @@ -579,9 +570,7 @@ def default_headers(self) -> dict[str, str | Omit]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: - if ( - self.access_token or self.agent_access_key or self.model_access_key - ) and headers.get("Authorization"): + if (self.access_token or self.agent_access_key or self.model_access_key) and headers.get("Authorization"): return if isinstance(custom_headers.get("Authorization"), Omit): return @@ -615,14 +604,10 @@ def copy( Create a new client instance re-using the same options given to the current client with optional overriding. """ if default_headers is not None and set_default_headers is not None: - raise ValueError( - "The `default_headers` and `set_default_headers` arguments are mutually exclusive" - ) + raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") if default_query is not None and set_default_query is not None: - raise ValueError( - "The `default_query` and `set_default_query` arguments are mutually exclusive" - ) + raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") headers = self._custom_headers if default_headers is not None: @@ -670,14 +655,10 @@ def _make_status_error( return _exceptions.BadRequestError(err_msg, response=response, body=body) if response.status_code == 401: - return _exceptions.AuthenticationError( - err_msg, response=response, body=body - ) + return _exceptions.AuthenticationError(err_msg, response=response, body=body) if response.status_code == 403: - return _exceptions.PermissionDeniedError( - err_msg, response=response, body=body - ) + return _exceptions.PermissionDeniedError(err_msg, response=response, body=body) if response.status_code == 404: return _exceptions.NotFoundError(err_msg, response=response, body=body) @@ -686,17 +667,13 @@ def _make_status_error( return _exceptions.ConflictError(err_msg, response=response, body=body) if response.status_code == 422: - return _exceptions.UnprocessableEntityError( - err_msg, response=response, body=body - ) + return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body) if response.status_code == 429: return _exceptions.RateLimitError(err_msg, response=response, body=body) if response.status_code >= 500: - return _exceptions.InternalServerError( - err_msg, response=response, body=body - ) + return _exceptions.InternalServerError(err_msg, response=response, body=body) return APIStatusError(err_msg, response=response, body=body) @@ -915,9 +892,7 @@ def knowledge_bases( AsyncKnowledgeBasesResourceWithStreamingResponse, ) - return AsyncKnowledgeBasesResourceWithStreamingResponse( - self._client.knowledge_bases - ) + return AsyncKnowledgeBasesResourceWithStreamingResponse(self._client.knowledge_bases) @cached_property def models(self) -> models.AsyncModelsResourceWithStreamingResponse: diff --git a/src/gradient/_models.py b/src/gradient/_models.py index b8387ce9..92f7c10b 100644 --- a/src/gradient/_models.py +++ b/src/gradient/_models.py @@ -304,7 +304,7 @@ def model_dump( exclude_none=exclude_none, ) - return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped + return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped @override def model_dump_json( diff --git a/src/gradient/_types.py b/src/gradient/_types.py index b44bb2d9..32375713 100644 --- a/src/gradient/_types.py +++ b/src/gradient/_types.py @@ -13,10 +13,21 @@ Mapping, TypeVar, Callable, + Iterator, Optional, Sequence, ) -from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable +from typing_extensions import ( + Set, + Literal, + Protocol, + TypeAlias, + TypedDict, + SupportsIndex, + overload, + override, + runtime_checkable, +) import httpx import pydantic @@ -217,3 +228,26 @@ class _GenericAlias(Protocol): class HttpxSendArgs(TypedDict, total=False): auth: httpx.Auth follow_redirects: bool + + +_T_co = TypeVar("_T_co", covariant=True) + + +if TYPE_CHECKING: + # This works because str.__contains__ does not accept object (either in typeshed or at runtime) + # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285 + class SequenceNotStr(Protocol[_T_co]): + @overload + def __getitem__(self, index: SupportsIndex, /) -> _T_co: ... + @overload + def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ... + def __contains__(self, value: object, /) -> bool: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[_T_co]: ... + def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ... + def count(self, value: Any, /) -> int: ... + def __reversed__(self) -> Iterator[_T_co]: ... +else: + # just point this to a normal `Sequence` at runtime to avoid having to special case + # deserializing our custom sequence type + SequenceNotStr = Sequence diff --git a/src/gradient/_utils/__init__.py b/src/gradient/_utils/__init__.py index d4fda26f..ca547ce5 100644 --- a/src/gradient/_utils/__init__.py +++ b/src/gradient/_utils/__init__.py @@ -38,6 +38,7 @@ extract_type_arg as extract_type_arg, is_iterable_type as is_iterable_type, is_required_type as is_required_type, + is_sequence_type as is_sequence_type, is_annotated_type as is_annotated_type, is_type_alias_type as is_type_alias_type, strip_annotated_type as strip_annotated_type, diff --git a/src/gradient/_utils/_typing.py b/src/gradient/_utils/_typing.py index 1bac9542..845cd6b2 100644 --- a/src/gradient/_utils/_typing.py +++ b/src/gradient/_utils/_typing.py @@ -26,6 +26,11 @@ def is_list_type(typ: type) -> bool: return (get_origin(typ) or typ) == list +def is_sequence_type(typ: type) -> bool: + origin = get_origin(typ) or typ + return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence + + def is_iterable_type(typ: type) -> bool: """If the given type is `typing.Iterable[T]`""" origin = get_origin(typ) or typ diff --git a/src/gradient/_version.py b/src/gradient/_version.py index 428a5fa9..c7adeab4 100644 --- a/src/gradient/_version.py +++ b/src/gradient/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "gradient" -__version__ = "3.0.0-beta.4" # x-release-please-version +__version__ = "3.0.0-beta.5" # x-release-please-version diff --git a/src/gradient/resources/agents/chat/completions.py b/src/gradient/resources/agents/chat/completions.py index 540a7890..88d6c241 100644 --- a/src/gradient/resources/agents/chat/completions.py +++ b/src/gradient/resources/agents/chat/completions.py @@ -472,11 +472,9 @@ def create( headers = {"Authorization": f"Bearer {self._client.agent_access_key}", **headers} return self._post( - ( - "/chat/completions?agent=true" - if self._client._base_url_overridden - else f"{self._client.agent_endpoint}/api/v1/chat/completions?agent=true" - ), + "/chat/completions?agent=true" + if self._client._base_url_overridden + else f"{self._client.agent_endpoint}/api/v1/chat/completions?agent=true", body=maybe_transform( { "messages": messages, @@ -960,11 +958,9 @@ async def create( headers = {"Authorization": f"Bearer {self._client.agent_access_key}", **headers} return await self._post( - ( - "/chat/completions?agent=true" - if self._client._base_url_overridden - else f"{self._client.agent_endpoint}/api/v1/chat/completions?agent=true" - ), + "/chat/completions?agent=true" + if self._client._base_url_overridden + else f"{self._client.agent_endpoint}/api/v1/chat/completions?agent=true", body=await async_maybe_transform( { "messages": messages, diff --git a/src/gradient/resources/chat/completions.py b/src/gradient/resources/chat/completions.py index 18b2a17a..3a412b10 100644 --- a/src/gradient/resources/chat/completions.py +++ b/src/gradient/resources/chat/completions.py @@ -62,9 +62,7 @@ def create( presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, - stream_options: ( - Optional[completion_create_params.StreamOptions] | NotGiven - ) = NOT_GIVEN, + stream_options: (Optional[completion_create_params.StreamOptions] | NotGiven) = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -193,9 +191,7 @@ def create( n: Optional[int] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream_options: ( - Optional[completion_create_params.StreamOptions] | NotGiven - ) = NOT_GIVEN, + stream_options: (Optional[completion_create_params.StreamOptions] | NotGiven) = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -323,9 +319,7 @@ def create( n: Optional[int] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream_options: ( - Optional[completion_create_params.StreamOptions] | NotGiven - ) = NOT_GIVEN, + stream_options: (Optional[completion_create_params.StreamOptions] | NotGiven) = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -453,9 +447,7 @@ def create( presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream_options: ( - Optional[completion_create_params.StreamOptions] | NotGiven - ) = NOT_GIVEN, + stream_options: (Optional[completion_create_params.StreamOptions] | NotGiven) = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -481,11 +473,9 @@ def create( } return self._post( - ( - "/chat/completions" - if self._client._base_url_overridden - else f"{self._client.inference_endpoint}/v1/chat/completions" - ), + "/chat/completions" + if self._client._base_url_overridden + else f"{self._client.inference_endpoint}/v1/chat/completions", body=maybe_transform( { "messages": messages, @@ -562,9 +552,7 @@ async def create( presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, - stream_options: ( - Optional[completion_create_params.StreamOptions] | NotGiven - ) = NOT_GIVEN, + stream_options: (Optional[completion_create_params.StreamOptions] | NotGiven) = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -693,9 +681,7 @@ async def create( n: Optional[int] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream_options: ( - Optional[completion_create_params.StreamOptions] | NotGiven - ) = NOT_GIVEN, + stream_options: (Optional[completion_create_params.StreamOptions] | NotGiven) = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -823,9 +809,7 @@ async def create( n: Optional[int] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream_options: ( - Optional[completion_create_params.StreamOptions] | NotGiven - ) = NOT_GIVEN, + stream_options: (Optional[completion_create_params.StreamOptions] | NotGiven) = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -953,9 +937,7 @@ async def create( presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream_options: ( - Optional[completion_create_params.StreamOptions] | NotGiven - ) = NOT_GIVEN, + stream_options: (Optional[completion_create_params.StreamOptions] | NotGiven) = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -970,10 +952,7 @@ async def create( timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> CompletionCreateResponse | AsyncStream[ChatCompletionChunk]: # This method requires an model_access_key to be set via client argument or environment variable - if ( - not hasattr(self._client, "model_access_key") - or not self._client.model_access_key - ): + if not hasattr(self._client, "model_access_key") or not self._client.model_access_key: raise TypeError( "Could not resolve authentication method. Expected model_access_key to be set for chat completions." ) @@ -984,11 +963,9 @@ async def create( } return await self._post( - ( - "/chat/completions" - if self._client._base_url_overridden - else f"{self._client.inference_endpoint}/chat/completions" - ), + "/chat/completions" + if self._client._base_url_overridden + else f"{self._client.inference_endpoint}/v1/chat/completions", body=await async_maybe_transform( { "messages": messages, diff --git a/tests/test_client.py b/tests/test_client.py index 347c89aa..9422604d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -59,9 +59,7 @@ def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float: def _get_open_connections(client: Gradient | AsyncGradient) -> int: transport = client._client._transport - assert isinstance(transport, httpx.HTTPTransport) or isinstance( - transport, httpx.AsyncHTTPTransport - ) + assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport) pool = transport._pool return len(pool._requests) @@ -78,9 +76,7 @@ class TestGradient: @pytest.mark.respx(base_url=base_url) def test_raw_response(self, respx_mock: MockRouter) -> None: - respx_mock.post("/foo").mock( - return_value=httpx.Response(200, json={"foo": "bar"}) - ) + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) response = self.client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 @@ -228,9 +224,7 @@ def test_copy_signature(self) -> None: continue copy_param = copy_signature.parameters.get(name) - assert ( - copy_param is not None - ), f"copy() signature is missing the {name} param" + assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif( sys.version_info >= (3, 10), @@ -260,9 +254,7 @@ def build_request(options: FinalRequestOptions) -> None: tracemalloc.stop() - def add_leak( - leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff - ) -> None: + def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None: if diff.count == 0: # Avoid false positives by considering only leaks (i.e. allocations that persist). return @@ -301,9 +293,7 @@ def add_leak( raise AssertionError() def test_request_timeout(self) -> None: - request = self.client._build_request( - FinalRequestOptions(method="get", url="/foo") - ) + request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT @@ -339,9 +329,7 @@ def test_http_client_timeout_option(self) -> None: http_client=http_client, ) - request = client._build_request( - FinalRequestOptions(method="get", url="/foo") - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) @@ -356,9 +344,7 @@ def test_http_client_timeout_option(self) -> None: http_client=http_client, ) - request = client._build_request( - FinalRequestOptions(method="get", url="/foo") - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT @@ -373,9 +359,7 @@ def test_http_client_timeout_option(self) -> None: http_client=http_client, ) - request = client._build_request( - FinalRequestOptions(method="get", url="/foo") - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default @@ -446,9 +430,7 @@ def test_validate_headers(self) -> None: client2._build_request(FinalRequestOptions(method="get", url="/foo")) request2 = client2._build_request( - FinalRequestOptions( - method="get", url="/foo", headers={"Authorization": Omit()} - ) + FinalRequestOptions(method="get", url="/foo", headers={"Authorization": Omit()}) ) assert request2.headers.get("Authorization") is None @@ -520,9 +502,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options( - default_headers={"X-Bar": "true"} - )._build_request( + request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -579,9 +559,7 @@ def test_multipart_repeating_array(self, client: Gradient) -> None: FinalRequestOptions.construct( method="post", url="/foo", - headers={ - "Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82" - }, + headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, json_data={"array": ["foo", "bar"]}, files=[("foo.txt", b"hello world")], ) @@ -613,9 +591,7 @@ class Model1(BaseModel): class Model2(BaseModel): foo: str - respx_mock.get("/foo").mock( - return_value=httpx.Response(200, json={"foo": "bar"}) - ) + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) @@ -631,9 +607,7 @@ class Model1(BaseModel): class Model2(BaseModel): foo: str - respx_mock.get("/foo").mock( - return_value=httpx.Response(200, json={"foo": "bar"}) - ) + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) @@ -646,9 +620,7 @@ class Model2(BaseModel): assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - def test_non_application_json_content_type_for_json_data( - self, respx_mock: MockRouter - ) -> None: + def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -821,9 +793,7 @@ def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: class Model(BaseModel): foo: str - respx_mock.get("/foo").mock( - return_value=httpx.Response(200, json={"foo": {"invalid": True}}) - ) + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: self.client.get("/foo", cast_to=Model) @@ -846,13 +816,9 @@ def test_default_stream_cls(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str - respx_mock.post("/foo").mock( - return_value=httpx.Response(200, json={"foo": "bar"}) - ) + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - stream = self.client.post( - "/foo", cast_to=Model, stream=True, stream_cls=Stream[Model] - ) + stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model]) assert isinstance(stream, Stream) stream.response.close() @@ -861,9 +827,7 @@ def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str - respx_mock.get("/foo").mock( - return_value=httpx.Response(200, text="my-custom-format") - ) + respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format")) strict_client = Gradient( base_url=base_url, @@ -909,9 +873,7 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - def test_parse_retry_after_header( - self, remaining_retries: int, retry_after: str, timeout: float - ) -> None: + def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: client = Gradient( base_url=base_url, access_token=access_token, @@ -922,21 +884,13 @@ def test_parse_retry_after_header( headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) - calculated = client._calculate_retry_timeout( - remaining_retries, options, headers - ) + calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - def test_retrying_timeout_errors_doesnt_leak( - self, respx_mock: MockRouter, client: Gradient - ) -> None: - respx_mock.post("/chat/completions").mock( - side_effect=httpx.TimeoutException("Test timeout error") - ) + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Gradient) -> None: + respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): client.chat.completions.with_streaming_response.create( @@ -951,13 +905,9 @@ def test_retrying_timeout_errors_doesnt_leak( assert _get_open_connections(self.client) == 0 - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - def test_retrying_status_errors_doesnt_leak( - self, respx_mock: MockRouter, client: Gradient - ) -> None: + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Gradient) -> None: respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): @@ -973,9 +923,7 @@ def test_retrying_status_errors_doesnt_leak( assert _get_open_connections(self.client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @pytest.mark.parametrize("failure_mode", ["status", "exception"]) def test_retries_taken( @@ -1011,15 +959,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: ) assert response.retries_taken == failures_before_success - assert ( - int(response.http_request.headers.get("x-stainless-retry-count")) - == failures_before_success - ) + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_omit_retry_count_header( self, client: Gradient, failures_before_success: int, respx_mock: MockRouter @@ -1048,14 +991,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: extra_headers={"x-stainless-retry-count": Omit()}, ) - assert ( - len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 - ) + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_overwrite_retry_count_header( self, client: Gradient, failures_before_success: int, respx_mock: MockRouter @@ -1112,17 +1051,11 @@ def test_default_client_creation(self) -> None: def test_follow_redirects(self, respx_mock: MockRouter) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( - return_value=httpx.Response( - 302, headers={"Location": f"{base_url}/redirected"} - ) - ) - respx_mock.get("/redirected").mock( - return_value=httpx.Response(200, json={"status": "ok"}) + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) + respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = self.client.post( - "/redirect", body={"key": "value"}, cast_to=httpx.Response - ) + response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @@ -1130,9 +1063,7 @@ def test_follow_redirects(self, respx_mock: MockRouter) -> None: def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( - return_value=httpx.Response( - 302, headers={"Location": f"{base_url}/redirected"} - ) + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: @@ -1159,9 +1090,7 @@ class TestAsyncGradient: @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_raw_response(self, respx_mock: MockRouter) -> None: - respx_mock.post("/foo").mock( - return_value=httpx.Response(200, json={"foo": "bar"}) - ) + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) response = await self.client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 @@ -1310,9 +1239,7 @@ def test_copy_signature(self) -> None: continue copy_param = copy_signature.parameters.get(name) - assert ( - copy_param is not None - ), f"copy() signature is missing the {name} param" + assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif( sys.version_info >= (3, 10), @@ -1342,9 +1269,7 @@ def build_request(options: FinalRequestOptions) -> None: tracemalloc.stop() - def add_leak( - leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff - ) -> None: + def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None: if diff.count == 0: # Avoid false positives by considering only leaks (i.e. allocations that persist). return @@ -1383,9 +1308,7 @@ def add_leak( raise AssertionError() async def test_request_timeout(self) -> None: - request = self.client._build_request( - FinalRequestOptions(method="get", url="/foo") - ) + request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT @@ -1421,9 +1344,7 @@ async def test_http_client_timeout_option(self) -> None: http_client=http_client, ) - request = client._build_request( - FinalRequestOptions(method="get", url="/foo") - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) @@ -1438,9 +1359,7 @@ async def test_http_client_timeout_option(self) -> None: http_client=http_client, ) - request = client._build_request( - FinalRequestOptions(method="get", url="/foo") - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT @@ -1455,9 +1374,7 @@ async def test_http_client_timeout_option(self) -> None: http_client=http_client, ) - request = client._build_request( - FinalRequestOptions(method="get", url="/foo") - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default @@ -1528,9 +1445,7 @@ def test_validate_headers(self) -> None: client2._build_request(FinalRequestOptions(method="get", url="/foo")) request2 = client2._build_request( - FinalRequestOptions( - method="get", url="/foo", headers={"Authorization": Omit()} - ) + FinalRequestOptions(method="get", url="/foo", headers={"Authorization": Omit()}) ) assert request2.headers.get("Authorization") is None @@ -1602,9 +1517,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options( - default_headers={"X-Bar": "true"} - )._build_request( + request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1661,9 +1574,7 @@ def test_multipart_repeating_array(self, async_client: AsyncGradient) -> None: FinalRequestOptions.construct( method="post", url="/foo", - headers={ - "Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82" - }, + headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, json_data={"array": ["foo", "bar"]}, files=[("foo.txt", b"hello world")], ) @@ -1695,13 +1606,9 @@ class Model1(BaseModel): class Model2(BaseModel): foo: str - respx_mock.get("/foo").mock( - return_value=httpx.Response(200, json={"foo": "bar"}) - ) + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get( - "/foo", cast_to=cast(Any, Union[Model1, Model2]) - ) + response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @@ -1715,28 +1622,20 @@ class Model1(BaseModel): class Model2(BaseModel): foo: str - respx_mock.get("/foo").mock( - return_value=httpx.Response(200, json={"foo": "bar"}) - ) + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get( - "/foo", cast_to=cast(Any, Union[Model1, Model2]) - ) + response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = await self.client.get( - "/foo", cast_to=cast(Any, Union[Model1, Model2]) - ) + response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - async def test_non_application_json_content_type_for_json_data( - self, respx_mock: MockRouter - ) -> None: + async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -1907,15 +1806,11 @@ async def test_client_context_manager(self) -> None: @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio - async def test_client_response_validation_error( - self, respx_mock: MockRouter - ) -> None: + async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: class Model(BaseModel): foo: str - respx_mock.get("/foo").mock( - return_value=httpx.Response(200, json={"foo": {"invalid": True}}) - ) + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: await self.client.get("/foo", cast_to=Model) @@ -1939,27 +1834,19 @@ async def test_default_stream_cls(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str - respx_mock.post("/foo").mock( - return_value=httpx.Response(200, json={"foo": "bar"}) - ) + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - stream = await self.client.post( - "/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model] - ) + stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model]) assert isinstance(stream, AsyncStream) await stream.response.aclose() @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio - async def test_received_text_for_expected_json( - self, respx_mock: MockRouter - ) -> None: + async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str - respx_mock.get("/foo").mock( - return_value=httpx.Response(200, text="my-custom-format") - ) + respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format")) strict_client = AsyncGradient( base_url=base_url, @@ -2006,9 +1893,7 @@ class Model(BaseModel): ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @pytest.mark.asyncio - async def test_parse_retry_after_header( - self, remaining_retries: int, retry_after: str, timeout: float - ) -> None: + async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: client = AsyncGradient( base_url=base_url, access_token=access_token, @@ -2019,21 +1904,15 @@ async def test_parse_retry_after_header( headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) - calculated = client._calculate_retry_timeout( - remaining_retries, options, headers - ) + calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_timeout_errors_doesnt_leak( self, respx_mock: MockRouter, async_client: AsyncGradient ) -> None: - respx_mock.post("/chat/completions").mock( - side_effect=httpx.TimeoutException("Test timeout error") - ) + respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): await async_client.chat.completions.with_streaming_response.create( @@ -2048,9 +1927,7 @@ async def test_retrying_timeout_errors_doesnt_leak( assert _get_open_connections(self.client) == 0 - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_status_errors_doesnt_leak( self, respx_mock: MockRouter, async_client: AsyncGradient @@ -2070,9 +1947,7 @@ async def test_retrying_status_errors_doesnt_leak( assert _get_open_connections(self.client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio @pytest.mark.parametrize("failure_mode", ["status", "exception"]) @@ -2109,15 +1984,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: ) assert response.retries_taken == failures_before_success - assert ( - int(response.http_request.headers.get("x-stainless-retry-count")) - == failures_before_success - ) + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_omit_retry_count_header( @@ -2150,14 +2020,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: extra_headers={"x-stainless-retry-count": Omit()}, ) - assert ( - len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 - ) + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) - @mock.patch( - "gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout - ) + @mock.patch("gradient._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_overwrite_retry_count_header( @@ -2228,24 +2094,18 @@ async def test_main() -> None: return_code = process.poll() if return_code is not None: if return_code != 0: - raise AssertionError( - "calling get_platform using asyncify resulted in a non-zero exit code" - ) + raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code") # success break if time.monotonic() - start_time > timeout: process.kill() - raise AssertionError( - "calling get_platform using asyncify resulted in a hung process" - ) + raise AssertionError("calling get_platform using asyncify resulted in a hung process") time.sleep(0.1) - async def test_proxy_environment_variables( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: # Test that the proxy environment variables are set correctly monkeypatch.setenv("HTTPS_PROXY", "https://example.org") @@ -2271,17 +2131,11 @@ async def test_default_client_creation(self) -> None: async def test_follow_redirects(self, respx_mock: MockRouter) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( - return_value=httpx.Response( - 302, headers={"Location": f"{base_url}/redirected"} - ) - ) - respx_mock.get("/redirected").mock( - return_value=httpx.Response(200, json={"status": "ok"}) + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) + respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = await self.client.post( - "/redirect", body={"key": "value"}, cast_to=httpx.Response - ) + response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @@ -2289,9 +2143,7 @@ async def test_follow_redirects(self, respx_mock: MockRouter) -> None: async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( - return_value=httpx.Response( - 302, headers={"Location": f"{base_url}/redirected"} - ) + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: diff --git a/tests/utils.py b/tests/utils.py index e150f00b..ac014538 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ import inspect import traceback import contextlib -from typing import Any, TypeVar, Iterator, cast +from typing import Any, TypeVar, Iterator, Sequence, cast from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type @@ -15,6 +15,7 @@ is_list_type, is_union_type, extract_type_arg, + is_sequence_type, is_annotated_type, is_type_alias_type, ) @@ -71,6 +72,13 @@ def assert_matches_type( if is_list_type(type_): return _assert_list_type(type_, value) + if is_sequence_type(type_): + assert isinstance(value, Sequence) + inner_type = get_args(type_)[0] + for entry in value: # type: ignore + assert_type(inner_type, entry) # type: ignore + return + if origin == str: assert isinstance(value, str) elif origin == int: