From ec25bed4069bf306496710759da7c393f6df76e9 Mon Sep 17 00:00:00 2001 From: Lev Vereshchagin Date: Thu, 5 Dec 2024 18:32:54 +0300 Subject: [PATCH] Move back to HTTPX (#13) * Revert "Use niquests (#7)" This reverts commit bfaec30ec9ef68d3258bb6c927da193090762b2f. * Leave good stuff --- .github/workflows/publish.yml | 2 +- Justfile | 10 --- README.md | 12 ++- any_llm_client/clients/openai.py | 68 +++++++++-------- any_llm_client/clients/yandexgpt.py | 66 ++++++++++------- any_llm_client/http.py | 111 +++++++++------------------- any_llm_client/main.py | 14 ++-- any_llm_client/retry.py | 2 +- any_llm_client/sse.py | 11 --- pyproject.toml | 9 +-- tests/conftest.py | 27 ------- tests/test_http.py | 47 +++++------- tests/test_openai_client.py | 73 ++++++++++-------- tests/test_static.py | 8 -- tests/test_yandexgpt_client.py | 71 +++++++++--------- tests/testing_app.py | 28 ------- 16 files changed, 224 insertions(+), 335 deletions(-) delete mode 100644 any_llm_client/sse.py delete mode 100644 tests/testing_app.py diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ed65661..787069e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v4 - uses: extractions/setup-just@v2 - - uses: astral-sh/setup-uv@v4 + - uses: astral-sh/setup-uv@v3 with: enable-cache: true cache-dependency-glob: "**/pyproject.toml" diff --git a/Justfile b/Justfile index a29b3b6..bae709c 100644 --- a/Justfile +++ b/Justfile @@ -10,18 +10,8 @@ lint: uv run --group lint ruff format uv run --group lint mypy . -_test-no-http *args: - uv run pytest --ignore tests/test_http.py {{ args }} - test *args: - #!/bin/bash - uv run litestar --app tests.testing_app:app run & - APP_PID=$! uv run pytest {{ args }} - TEST_RESULT=$? - kill $APP_PID - wait $APP_PID 2>/dev/null - exit $TEST_RESULT publish: rm -rf dist diff --git a/README.md b/README.md index c8345bd..395ac8e 100644 --- a/README.md +++ b/README.md @@ -162,25 +162,23 @@ async with any_llm_client.OpenAIClient(config, ...) as client: #### Timeouts, proxy & other HTTP settings -Pass custom [niquests](https://niquests.readthedocs.io) kwargs to `any_llm_client.get_client()`: +Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`: ```python -import urllib3 +import httpx import any_llm_client async with any_llm_client.get_client( ..., - proxies={"https://api.openai.com": "http://localhost:8030"}, - timeout=urllib3.Timeout(total=10.0, connect=5.0), + mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")}, + timeout=httpx.Timeout(None, connect=5.0), ) as client: ... ``` -`timeout` and `proxies` parameters are special cased here: `niquests.AsyncSession` doesn't receive them by default. - -Default timeout is `urllib3.Timeout(total=None, connect=5.0)`. +Default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect, unlimited on read, write or pool). 
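As a further illustration (a sketch, not text from this patch), the retry settings from the next section can be combined with any of the HTTPX kwargs above; `RequestRetryConfig` and its `attempts` field are assumed here from the `any_llm_client/retry.py` docstring further down in this diff:

```python
import httpx

import any_llm_client
from any_llm_client.retry import RequestRetryConfig  # import path used inside the library

async with any_llm_client.get_client(
    ...,
    # assumed field name, based on the retry.py docstring ("by default 3 instead of 10")
    request_retry=RequestRetryConfig(attempts=5),
    # all remaining kwargs are passed through to httpx.AsyncClient
    timeout=httpx.Timeout(None, connect=5.0),
) as client:
    ...
```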
#### Retries diff --git a/any_llm_client/clients/openai.py b/any_llm_client/clients/openai.py index c9ab16e..65bb5d2 100644 --- a/any_llm_client/clients/openai.py +++ b/any_llm_client/clients/openai.py @@ -6,7 +6,8 @@ from http import HTTPStatus import annotated_types -import niquests +import httpx +import httpx_sse import pydantic import typing_extensions @@ -19,9 +20,8 @@ OutOfTokensOrSymbolsError, UserMessage, ) -from any_llm_client.http import HttpClient, HttpStatusError +from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request from any_llm_client.retry import RequestRetryConfig -from any_llm_client.sse import parse_sse_events OPENAI_AUTH_TOKEN_ENV_NAME: typing.Final = "ANY_LLM_CLIENT_OPENAI_AUTH_TOKEN" @@ -99,18 +99,16 @@ def _make_user_assistant_alternate_messages( yield ChatCompletionsMessage(role=current_message_role, content="\n\n".join(current_message_content_chunks)) -def _handle_status_error(error: HttpStatusError) -> typing.NoReturn: - if ( - error.status_code == HTTPStatus.BAD_REQUEST and b"Please reduce the length of the messages" in error.content - ): # vLLM - raise OutOfTokensOrSymbolsError(response_content=error.content) - raise LLMError(response_content=error.content) +def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn: + if status_code == HTTPStatus.BAD_REQUEST and b"Please reduce the length of the messages" in content: # vLLM + raise OutOfTokensOrSymbolsError(response_content=content) + raise LLMError(response_content=content) @dataclasses.dataclass(slots=True, init=False) class OpenAIClient(LLMClient): config: OpenAIConfig - http_client: HttpClient + httpx_client: httpx.AsyncClient request_retry: RequestRetryConfig def __init__( @@ -118,15 +116,14 @@ def __init__( config: OpenAIConfig, *, request_retry: RequestRetryConfig | None = None, - **niquests_kwargs: typing.Any, # noqa: ANN401 + **httpx_kwargs: typing.Any, # noqa: ANN401 ) -> None: self.config = config - self.http_client = HttpClient( - request_retry=request_retry or RequestRetryConfig(), niquests_kwargs=niquests_kwargs - ) + self.request_retry = request_retry or RequestRetryConfig() + self.httpx_client = get_http_client_from_kwargs(httpx_kwargs) - def _build_request(self, payload: dict[str, typing.Any]) -> niquests.Request: - return niquests.Request( + def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request: + return self.httpx_client.build_request( method="POST", url=str(self.config.url), json=payload, @@ -155,17 +152,24 @@ async def request_llm_message( **extra or {}, ).model_dump(mode="json") try: - response: typing.Final = await self.http_client.request(self._build_request(payload)) - except HttpStatusError as exception: - _handle_status_error(exception) - return ChatCompletionsNotStreamingResponse.model_validate_json(response).choices[0].message.content + response: typing.Final = await make_http_request( + httpx_client=self.httpx_client, + request_retry=self.request_retry, + build_request=lambda: self._build_request(payload), + ) + except httpx.HTTPStatusError as exception: + _handle_status_error(status_code=exception.response.status_code, content=exception.response.content) + try: + return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content + finally: + await response.aclose() - async def _iter_partial_responses(self, response: typing.AsyncIterable[bytes]) -> typing.AsyncIterable[str]: + async def _iter_partial_responses(self, response: httpx.Response) -> 
typing.AsyncIterable[str]: text_chunks: typing.Final = [] - async for one_event in parse_sse_events(response): - if one_event.data == "[DONE]": + async for event in httpx_sse.EventSource(response).aiter_sse(): + if event.data == "[DONE]": break - validated_response = ChatCompletionsStreamingEvent.model_validate_json(one_event.data) + validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data) if not (one_chunk := validated_response.choices[0].delta.content): continue text_chunks.append(one_chunk) @@ -183,13 +187,19 @@ async def stream_llm_partial_messages( **extra or {}, ).model_dump(mode="json") try: - async with self.http_client.stream(request=self._build_request(payload)) as response: + async with make_streaming_http_request( + httpx_client=self.httpx_client, + request_retry=self.request_retry, + build_request=lambda: self._build_request(payload), + ) as response: yield self._iter_partial_responses(response) - except HttpStatusError as exception: - _handle_status_error(exception) + except httpx.HTTPStatusError as exception: + content: typing.Final = await exception.response.aread() + await exception.response.aclose() + _handle_status_error(status_code=exception.response.status_code, content=content) async def __aenter__(self) -> typing_extensions.Self: - await self.http_client.__aenter__() + await self.httpx_client.__aenter__() return self async def __aexit__( @@ -198,4 +208,4 @@ async def __aexit__( exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: - await self.http_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback) + await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback) diff --git a/any_llm_client/clients/yandexgpt.py b/any_llm_client/clients/yandexgpt.py index 305a3f5..10c8818 100644 --- a/any_llm_client/clients/yandexgpt.py +++ b/any_llm_client/clients/yandexgpt.py @@ -6,12 +6,12 @@ from http import HTTPStatus import annotated_types -import niquests +import httpx import pydantic import typing_extensions from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError, UserMessage -from any_llm_client.http import HttpClient, HttpStatusError +from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request from any_llm_client.retry import RequestRetryConfig @@ -61,34 +61,34 @@ class YandexGPTResponse(pydantic.BaseModel): result: YandexGPTResult -def _handle_status_error(error: HttpStatusError) -> typing.NoReturn: - if error.status_code == HTTPStatus.BAD_REQUEST and ( - b"number of input tokens must be no more than" in error.content - or (b"text length is" in error.content and b"which is outside the range" in error.content) +def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn: + if status_code == HTTPStatus.BAD_REQUEST and ( + b"number of input tokens must be no more than" in content + or (b"text length is" in content and b"which is outside the range" in content) ): - raise OutOfTokensOrSymbolsError(response_content=error.content) - raise LLMError(response_content=error.content) + raise OutOfTokensOrSymbolsError(response_content=content) + raise LLMError(response_content=content) @dataclasses.dataclass(slots=True, init=False) class YandexGPTClient(LLMClient): config: YandexGPTConfig - http_client: HttpClient + httpx_client: httpx.AsyncClient + request_retry: RequestRetryConfig def __init__( self, config: YandexGPTConfig, *, request_retry: RequestRetryConfig | 
None = None, - **niquests_kwargs: typing.Any, # noqa: ANN401 + **httpx_kwargs: typing.Any, # noqa: ANN401 ) -> None: self.config = config - self.http_client = HttpClient( - request_retry=request_retry or RequestRetryConfig(), niquests_kwargs=niquests_kwargs - ) + self.request_retry = request_retry or RequestRetryConfig() + self.httpx_client = get_http_client_from_kwargs(httpx_kwargs) - def _build_request(self, payload: dict[str, typing.Any]) -> niquests.Request: - return niquests.Request( + def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request: + return self.httpx_client.build_request( method="POST", url=str(self.config.url), json=payload, @@ -121,14 +121,18 @@ async def request_llm_message( ) try: - response: typing.Final = await self.http_client.request(self._build_request(payload)) - except HttpStatusError as exception: - _handle_status_error(exception) - - return YandexGPTResponse.model_validate_json(response).result.alternatives[0].message.text - - async def _iter_completion_messages(self, response: typing.AsyncIterable[bytes]) -> typing.AsyncIterable[str]: - async for one_line in response: + response: typing.Final = await make_http_request( + httpx_client=self.httpx_client, + request_retry=self.request_retry, + build_request=lambda: self._build_request(payload), + ) + except httpx.HTTPStatusError as exception: + _handle_status_error(status_code=exception.response.status_code, content=exception.response.content) + + return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text + + async def _iter_completion_messages(self, response: httpx.Response) -> typing.AsyncIterable[str]: + async for one_line in response.aiter_lines(): validated_response = YandexGPTResponse.model_validate_json(one_line) yield validated_response.result.alternatives[0].message.text @@ -141,13 +145,19 @@ async def stream_llm_partial_messages( ) try: - async with self.http_client.stream(request=self._build_request(payload)) as response: + async with make_streaming_http_request( + httpx_client=self.httpx_client, + request_retry=self.request_retry, + build_request=lambda: self._build_request(payload), + ) as response: yield self._iter_completion_messages(response) - except HttpStatusError as exception: - _handle_status_error(exception) + except httpx.HTTPStatusError as exception: + content: typing.Final = await exception.response.aread() + await exception.response.aclose() + _handle_status_error(status_code=exception.response.status_code, content=content) async def __aenter__(self) -> typing_extensions.Self: - await self.http_client.__aenter__() + await self.httpx_client.__aenter__() return self async def __aexit__( @@ -156,4 +166,4 @@ async def __aexit__( exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: - await self.http_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback) + await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback) diff --git a/any_llm_client/http.py b/any_llm_client/http.py index bb7f23a..7d05030 100644 --- a/any_llm_client/http.py +++ b/any_llm_client/http.py @@ -1,95 +1,52 @@ import contextlib import dataclasses -import types import typing -import niquests +import httpx import stamina -import typing_extensions -import urllib3 from any_llm_client.retry import RequestRetryConfig -DEFAULT_HTTP_TIMEOUT: typing.Final = urllib3.Timeout(total=None, connect=5.0) +DEFAULT_HTTP_TIMEOUT: typing.Final = httpx.Timeout(None, connect=5.0) -@dataclasses.dataclass 
-class HttpStatusError(Exception): - status_code: int - content: bytes +def get_http_client_from_kwargs(kwargs: dict[str, typing.Any]) -> httpx.AsyncClient: + kwargs_with_defaults: typing.Final = kwargs.copy() + kwargs_with_defaults.setdefault("timeout", DEFAULT_HTTP_TIMEOUT) + return httpx.AsyncClient(**kwargs_with_defaults) -@dataclasses.dataclass(slots=True, init=False) -class HttpClient: - client: niquests.AsyncSession - timeout: urllib3.Timeout - _make_not_streaming_request_with_retries: typing.Callable[[niquests.Request], typing.Awaitable[niquests.Response]] - _make_streaming_request_with_retries: typing.Callable[[niquests.Request], typing.Awaitable[niquests.AsyncResponse]] - _retried_exceptions: typing.ClassVar = (niquests.HTTPError, HttpStatusError) - - def __init__(self, request_retry: RequestRetryConfig, niquests_kwargs: dict[str, typing.Any]) -> None: - modified_kwargs: typing.Final = niquests_kwargs.copy() - self.timeout = modified_kwargs.pop("timeout", DEFAULT_HTTP_TIMEOUT) - proxies: typing.Final = modified_kwargs.pop("proxies", None) - - self.client = niquests.AsyncSession(**modified_kwargs) - if proxies: - self.client.proxies = proxies - - request_retry_dict: typing.Final = dataclasses.asdict(request_retry) - - self._make_not_streaming_request_with_retries = stamina.retry( - on=self._retried_exceptions, **request_retry_dict - )(self._make_not_streaming_request) - self._make_streaming_request_with_retries = stamina.retry(on=self._retried_exceptions, **request_retry_dict)( - self._make_streaming_request - ) - - async def _make_not_streaming_request(self, request: niquests.Request) -> niquests.Response: - response: typing.Final = await self.client.send(self.client.prepare_request(request), timeout=self.timeout) - try: - response.raise_for_status() - except niquests.HTTPError as exception: - raise HttpStatusError(status_code=response.status_code, content=response.content) from exception # type: ignore[arg-type] - finally: - response.close() +async def make_http_request( + *, + httpx_client: httpx.AsyncClient, + request_retry: RequestRetryConfig, + build_request: typing.Callable[[], httpx.Request], +) -> httpx.Response: + @stamina.retry(on=httpx.HTTPError, **dataclasses.asdict(request_retry)) + async def make_request_with_retries() -> httpx.Response: + response: typing.Final = await httpx_client.send(build_request()) + response.raise_for_status() return response - async def request(self, request: niquests.Request) -> bytes: - response: typing.Final = await self._make_not_streaming_request_with_retries(request) - return response.content # type: ignore[return-value] - - async def _make_streaming_request(self, request: niquests.Request) -> niquests.AsyncResponse: - response: typing.Final = await self.client.send( - self.client.prepare_request(request), stream=True, timeout=self.timeout - ) - try: - response.raise_for_status() - except niquests.HTTPError as exception: - status_code: typing.Final = response.status_code - content: typing.Final = await response.content # type: ignore[misc] - await response.close() # type: ignore[misc] - raise HttpStatusError(status_code=status_code, content=content) from exception # type: ignore[arg-type] - return response # type: ignore[return-value] + return await make_request_with_retries() - @contextlib.asynccontextmanager - async def stream(self, request: niquests.Request) -> typing.AsyncIterator[typing.AsyncIterable[bytes]]: - response: typing.Final = await self._make_streaming_request_with_retries(request) - try: - response.__aenter__() - yield 
response.iter_lines() # type: ignore[misc] - finally: - await response.raw.close() # type: ignore[union-attr] - async def __aenter__(self) -> typing_extensions.Self: - await self.client.__aenter__() # type: ignore[no-untyped-call] - return self +@contextlib.asynccontextmanager +async def make_streaming_http_request( + *, + httpx_client: httpx.AsyncClient, + request_retry: RequestRetryConfig, + build_request: typing.Callable[[], httpx.Request], +) -> typing.AsyncIterator[httpx.Response]: + @stamina.retry(on=httpx.HTTPError, **dataclasses.asdict(request_retry)) + async def make_request_with_retries() -> httpx.Response: + response: typing.Final = await httpx_client.send(build_request(), stream=True) + response.raise_for_status() + return response - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: types.TracebackType | None, - ) -> None: - await self.client.__aexit__(exc_type, exc_value, traceback) # type: ignore[no-untyped-call] + response: typing.Final = await make_request_with_retries() + try: + yield response + finally: + await response.aclose() diff --git a/any_llm_client/main.py b/any_llm_client/main.py index d9c6c76..d734e67 100644 --- a/any_llm_client/main.py +++ b/any_llm_client/main.py @@ -19,7 +19,7 @@ def get_client( config: AnyLLMConfig, *, request_retry: RequestRetryConfig | None = None, - **niquests_kwargs: typing.Any, # noqa: ANN401 + **httpx_kwargs: typing.Any, # noqa: ANN401 ) -> LLMClient: ... else: @@ -28,7 +28,7 @@ def get_client( config: typing.Any, # noqa: ANN401, ARG001 *, request_retry: RequestRetryConfig | None = None, # noqa: ARG001 - **niquests_kwargs: typing.Any, # noqa: ANN401, ARG001 + **httpx_kwargs: typing.Any, # noqa: ANN401, ARG001 ) -> LLMClient: raise AssertionError("unknown LLM config type") @@ -37,24 +37,24 @@ def _( config: YandexGPTConfig, *, request_retry: RequestRetryConfig | None = None, - **niquests_kwargs: typing.Any, # noqa: ANN401 + **httpx_kwargs: typing.Any, # noqa: ANN401 ) -> LLMClient: - return YandexGPTClient(config=config, request_retry=request_retry, **niquests_kwargs) + return YandexGPTClient(config=config, request_retry=request_retry, **httpx_kwargs) @get_client.register def _( config: OpenAIConfig, *, request_retry: RequestRetryConfig | None = None, - **niquests_kwargs: typing.Any, # noqa: ANN401 + **httpx_kwargs: typing.Any, # noqa: ANN401 ) -> LLMClient: - return OpenAIClient(config=config, request_retry=request_retry, **niquests_kwargs) + return OpenAIClient(config=config, request_retry=request_retry, **httpx_kwargs) @get_client.register def _( config: MockLLMConfig, *, request_retry: RequestRetryConfig | None = None, # noqa: ARG001 - **niquests_kwargs: typing.Any, # noqa: ANN401, ARG001 + **httpx_kwargs: typing.Any, # noqa: ANN401, ARG001 ) -> LLMClient: return MockLLMClient(config=config) diff --git a/any_llm_client/retry.py b/any_llm_client/retry.py index 3aded39..d043322 100644 --- a/any_llm_client/retry.py +++ b/any_llm_client/retry.py @@ -4,7 +4,7 @@ @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) class RequestRetryConfig: - """Request retry configuration that is passed to `stamina.retry`. Applies to niquests.HTTPError. + """Request retry configuration that is passed to `stamina.retry`. Applies to httpx.HTTPError. Uses defaults from `stamina.retry` except for attempts: by default 3 instead of 10. 
See more at https://stamina.hynek.me/en/stable/api.html#stamina.retry diff --git a/any_llm_client/sse.py b/any_llm_client/sse.py deleted file mode 100644 index cdd61a7..0000000 --- a/any_llm_client/sse.py +++ /dev/null @@ -1,11 +0,0 @@ -import typing - -import httpx_sse -from httpx_sse._decoders import SSEDecoder - - -async def parse_sse_events(response: typing.AsyncIterable[bytes]) -> typing.AsyncIterator[httpx_sse.ServerSentEvent]: - sse_decoder: typing.Final = SSEDecoder() - async for one_line in response: - if event := sse_decoder.decode(one_line.decode().rstrip("\n")): - yield event diff --git a/pyproject.toml b/pyproject.toml index f374081..2eeb88e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,15 +8,16 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Software Development :: Libraries", "Topic :: System :: Networking", "Typing :: Typed", ] authors = [{ name = "Lev Vereshchagin", email = "mail@vrslev.com" }] -requires-python = ">=3.10,<3.13" +requires-python = ">=3.10" dependencies = [ "httpx-sse>=0.4.0", - "niquests>=3.11.1", + "httpx>=0.27.2", "pydantic>=2.9.2", "stamina>=24.3.0", ] @@ -26,7 +27,6 @@ dynamic = ["version"] dev = [ "anyio", "faker", - "litestar[standard]", "polyfactory", "pydantic-settings", "pytest-cov", @@ -88,6 +88,3 @@ addopts = "--cov=." skip_covered = true show_missing = true exclude_also = ["if typing.TYPE_CHECKING:"] - -[tool.coverage.run] -omit = ["tests/testing_app.py"] diff --git a/tests/conftest.py b/tests/conftest.py index 125f440..c003365 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import contextlib import typing -from unittest import mock import pytest import stamina @@ -34,29 +33,3 @@ async def consume_llm_partial_responses( ) -> list[str]: async with request_llm_partial_responses_context_manager as response_iterable: return [one_item async for one_item in response_iterable] - - -def _make_async_stream_iterable(lines: str) -> typing.Any: # noqa: ANN401 - async def iter_lines() -> typing.AsyncIterable[bytes]: - for line in lines.splitlines(): - yield line.encode() - - return iter_lines() - - -def mock_http_client(llm_client: any_llm_client.LLMClient, request_mock: mock.AsyncMock) -> any_llm_client.LLMClient: - assert hasattr(llm_client, "http_client") - llm_client.http_client = mock.Mock( - request=request_mock, - stream=mock.Mock( - return_value=mock.Mock( - __aenter__=( - mock.AsyncMock(return_value=_make_async_stream_iterable(request_mock.return_value)) - if isinstance(request_mock.return_value, str) - else request_mock - ), - __aexit__=mock.AsyncMock(return_value=None), - ) - ), - ) - return llm_client diff --git a/tests/test_http.py b/tests/test_http.py index 35cb0e2..b7e4f0e 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,40 +1,27 @@ +import copy import typing -from http import HTTPStatus -import niquests -import pytest +import httpx -from any_llm_client.http import HttpClient, HttpStatusError -from any_llm_client.retry import RequestRetryConfig +from any_llm_client.http import DEFAULT_HTTP_TIMEOUT, get_http_client_from_kwargs -BASE_URL: typing.Final = "http://127.0.0.1:8000" +class TestGetHttpClientFromKwargs: + def test_http_timeout_is_added(self) -> None: + original_kwargs: typing.Final = {"mounts": {"http://": None}} + passed_kwargs: typing.Final = copy.deepcopy(original_kwargs) + client: typing.Final = get_http_client_from_kwargs(passed_kwargs) 
-async def test_http_client_request_ok() -> None: - client: typing.Final = HttpClient(request_retry=RequestRetryConfig(), niquests_kwargs={}) - result: typing.Final = await client.request(niquests.Request(method="GET", url=f"{BASE_URL}/request-ok")) - assert result == b'{"ok":true}' + assert client.timeout == DEFAULT_HTTP_TIMEOUT + assert original_kwargs == passed_kwargs + def test_http_timeout_is_not_modified_if_set(self) -> None: + timeout: typing.Final = httpx.Timeout(7, connect=5, read=3) + original_kwargs: typing.Final = {"mounts": {"http://": None}, "timeout": timeout} + passed_kwargs: typing.Final = copy.deepcopy(original_kwargs) -async def test_http_client_request_rail() -> None: - client: typing.Final = HttpClient(request_retry=RequestRetryConfig(), niquests_kwargs={}) - with pytest.raises(HttpStatusError) as exc_info: - await client.request(niquests.Request(method="GET", url=f"{BASE_URL}/request-fail")) - assert exc_info.value.status_code == HTTPStatus.IM_A_TEAPOT - assert exc_info.value.content == b'{"ok":false}' + client: typing.Final = get_http_client_from_kwargs(passed_kwargs) - -async def test_http_client_stream_ok() -> None: - client: typing.Final = HttpClient(request_retry=RequestRetryConfig(), niquests_kwargs={}) - async with client.stream(niquests.Request(method="GET", url=f"{BASE_URL}/stream-ok")) as response: - result: typing.Final = [one_chunk async for one_chunk in response] - assert result == [b"ok", b"true"] - - -async def test_http_client_stream_rail() -> None: - client: typing.Final = HttpClient(request_retry=RequestRetryConfig(), niquests_kwargs={}) - with pytest.raises(HttpStatusError) as exc_info: - await client.stream(niquests.Request(method="GET", url=f"{BASE_URL}/stream-fail")).__aenter__() - assert exc_info.value.status_code == HTTPStatus.IM_A_TEAPOT - assert exc_info.value.content == b"ok\nfalse" + assert client.timeout == timeout + assert original_kwargs == passed_kwargs diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py index fb3d71a..c563623 100644 --- a/tests/test_openai_client.py +++ b/tests/test_openai_client.py @@ -1,7 +1,7 @@ import typing -from unittest import mock import faker +import httpx import pydantic import pytest from polyfactory.factories.pydantic_factory import ModelFactory @@ -15,8 +15,7 @@ OneStreamingChoice, OneStreamingChoiceDelta, ) -from any_llm_client.http import HttpStatusError -from tests.conftest import LLMFuncRequestFactory, consume_llm_partial_responses, mock_http_client +from tests.conftest import LLMFuncRequestFactory, consume_llm_partial_responses class OpenAIConfigFactory(ModelFactory[any_llm_client.OpenAIConfig]): ... @@ -25,25 +24,32 @@ class OpenAIConfigFactory(ModelFactory[any_llm_client.OpenAIConfig]): ... 
class TestOpenAIRequestLLMResponse: async def test_ok(self, faker: faker.Faker) -> None: expected_result: typing.Final = faker.pystr() - response: typing.Final = ChatCompletionsNotStreamingResponse( - choices=[ - OneNotStreamingChoice( - message=ChatCompletionsMessage(role=any_llm_client.MessageRole.assistant, content=expected_result) - ) - ] - ).model_dump_json() - client: typing.Final = mock_http_client( - any_llm_client.get_client(OpenAIConfigFactory.build()), mock.AsyncMock(return_value=response) + response: typing.Final = httpx.Response( + 200, + json=ChatCompletionsNotStreamingResponse( + choices=[ + OneNotStreamingChoice( + message=ChatCompletionsMessage( + role=any_llm_client.MessageRole.assistant, content=expected_result + ) + ) + ] + ).model_dump(mode="json"), ) - result: typing.Final = await client.request_llm_message(**LLMFuncRequestFactory.build()) + result: typing.Final = await any_llm_client.get_client( + OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response) + ).request_llm_message(**LLMFuncRequestFactory.build()) assert result == expected_result async def test_fails_without_alternatives(self) -> None: - response: typing.Final = ChatCompletionsNotStreamingResponse.model_construct(choices=[]).model_dump(mode="json") - client: typing.Final = mock_http_client( - any_llm_client.get_client(OpenAIConfigFactory.build()), mock.AsyncMock(return_value=response) + response: typing.Final = httpx.Response( + 200, + json=ChatCompletionsNotStreamingResponse.model_construct(choices=[]).model_dump(mode="json"), + ) + client: typing.Final = any_llm_client.get_client( + OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response) ) with pytest.raises(pydantic.ValidationError): @@ -71,7 +77,9 @@ async def test_ok(self, faker: faker.Faker) -> None: "Hi there. How is you", "Hi there. 
How is your day?", ] - response: typing.Final = ( + config: typing.Final = OpenAIConfigFactory.build() + func_request: typing.Final = LLMFuncRequestFactory.build() + response_content: typing.Final = ( "\n\n".join( "data: " + ChatCompletionsStreamingEvent(choices=[OneStreamingChoice(delta=one_message)]).model_dump_json() @@ -79,22 +87,24 @@ async def test_ok(self, faker: faker.Faker) -> None: ) + f"\n\ndata: [DONE]\n\ndata: {faker.pystr()}\n\n" ) - client: typing.Final = mock_http_client( - any_llm_client.get_client(OpenAIConfigFactory.build()), mock.AsyncMock(return_value=response) + response: typing.Final = httpx.Response( + 200, headers={"Content-Type": "text/event-stream"}, content=response_content ) + client: typing.Final = any_llm_client.get_client(config, transport=httpx.MockTransport(lambda _: response)) - result: typing.Final = await consume_llm_partial_responses( - client.stream_llm_partial_messages(**LLMFuncRequestFactory.build()) - ) + result: typing.Final = await consume_llm_partial_responses(client.stream_llm_partial_messages(**func_request)) assert result == expected_result async def test_fails_without_alternatives(self) -> None: - response: typing.Final = ( + response_content: typing.Final = ( f"data: {ChatCompletionsStreamingEvent.model_construct(choices=[]).model_dump_json()}\n\n" ) - client: typing.Final = mock_http_client( - any_llm_client.get_client(OpenAIConfigFactory.build()), mock.AsyncMock(return_value=response) + response: typing.Final = httpx.Response( + 200, headers={"Content-Type": "text/event-stream"}, content=response_content + ) + client: typing.Final = any_llm_client.get_client( + OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response) ) with pytest.raises(pydantic.ValidationError): @@ -105,9 +115,8 @@ class TestOpenAILLMErrors: @pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("status_code", [400, 500]) async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None: - client: typing.Final = mock_http_client( - any_llm_client.get_client(OpenAIConfigFactory.build()), - mock.AsyncMock(side_effect=HttpStatusError(status_code=status_code, content=b"")), + client: typing.Final = any_llm_client.get_client( + OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code)) ) coroutine: typing.Final = ( @@ -128,10 +137,10 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> b'{"object":"error","message":"This model\'s maximum context length is 16384 tokens. 
However, you requested 100000 tokens in the messages, Please reduce the length of the messages.","type":"BadRequestError","param":null,"code":400}', # noqa: E501 ], ) - async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes) -> None: - client: typing.Final = mock_http_client( - any_llm_client.get_client(OpenAIConfigFactory.build()), - mock.AsyncMock(side_effect=HttpStatusError(status_code=400, content=content)), + async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes | None) -> None: + response: typing.Final = httpx.Response(400, content=content) + client: typing.Final = any_llm_client.get_client( + OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response) ) coroutine: typing.Final = ( diff --git a/tests/test_static.py b/tests/test_static.py index 4045573..5ec106f 100644 --- a/tests/test_static.py +++ b/tests/test_static.py @@ -10,8 +10,6 @@ import any_llm_client from any_llm_client.clients.openai import ChatCompletionsRequest from any_llm_client.clients.yandexgpt import YandexGPTRequest -from any_llm_client.http import HttpClient -from any_llm_client.retry import RequestRetryConfig from tests.conftest import LLMFuncRequest @@ -49,12 +47,6 @@ def test_llm_func_request_has_same_annotations_as_llm_client_methods() -> None: assert all(annotations == all_annotations[0] for annotations in all_annotations) -def test_proxies_are_set_on_http_client(faker: faker.Faker) -> None: - proxies: typing.Final = faker.pydict() - http_client: typing.Final = HttpClient(request_retry=RequestRetryConfig(), niquests_kwargs={"proxies": proxies}) - assert http_client.client.proxies == proxies - - @pytest.mark.parametrize("model_type", [YandexGPTRequest, ChatCompletionsRequest]) def test_dumped_llm_request_payload_dump_has_extra_data(model_type: type[pydantic.BaseModel]) -> None: extra: typing.Final = {"hi": "there", "hi-hi": "there-there"} diff --git a/tests/test_yandexgpt_client.py b/tests/test_yandexgpt_client.py index 8d7c079..8160ac0 100644 --- a/tests/test_yandexgpt_client.py +++ b/tests/test_yandexgpt_client.py @@ -1,15 +1,14 @@ import typing -from unittest import mock import faker +import httpx import pydantic import pytest from polyfactory.factories.pydantic_factory import ModelFactory import any_llm_client from any_llm_client.clients.yandexgpt import YandexGPTAlternative, YandexGPTResponse, YandexGPTResult -from any_llm_client.http import HttpStatusError -from tests.conftest import LLMFuncRequestFactory, consume_llm_partial_responses, mock_http_client +from tests.conftest import LLMFuncRequestFactory, consume_llm_partial_responses class YandexGPTConfigFactory(ModelFactory[any_llm_client.YandexGPTConfig]): ... @@ -18,25 +17,27 @@ class YandexGPTConfigFactory(ModelFactory[any_llm_client.YandexGPTConfig]): ... 
class TestYandexGPTRequestLLMResponse: async def test_ok(self, faker: faker.Faker) -> None: expected_result: typing.Final = faker.pystr() - response: typing.Final = YandexGPTResponse( - result=YandexGPTResult( - alternatives=[YandexGPTAlternative(message=any_llm_client.AssistantMessage(expected_result))] - ) - ).model_dump_json() - client: typing.Final = mock_http_client( - any_llm_client.get_client(YandexGPTConfigFactory.build()), mock.AsyncMock(return_value=response) + response: typing.Final = httpx.Response( + 200, + json=YandexGPTResponse( + result=YandexGPTResult( + alternatives=[YandexGPTAlternative(message=any_llm_client.AssistantMessage(expected_result))] + ) + ).model_dump(mode="json"), ) - result: typing.Final = await client.request_llm_message(**LLMFuncRequestFactory.build()) + result: typing.Final = await any_llm_client.get_client( + YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response) + ).request_llm_message(**LLMFuncRequestFactory.build()) assert result == expected_result async def test_fails_without_alternatives(self) -> None: - response: typing.Final = YandexGPTResponse( - result=YandexGPTResult.model_construct(alternatives=[]) - ).model_dump_json() - client: typing.Final = mock_http_client( - any_llm_client.get_client(YandexGPTConfigFactory.build()), mock.AsyncMock(return_value=response) + response: typing.Final = httpx.Response( + 200, json=YandexGPTResponse(result=YandexGPTResult.model_construct(alternatives=[])).model_dump(mode="json") + ) + client: typing.Final = any_llm_client.get_client( + YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response) ) with pytest.raises(pydantic.ValidationError): @@ -46,8 +47,9 @@ async def test_fails_without_alternatives(self) -> None: class TestYandexGPTRequestLLMPartialResponses: async def test_ok(self, faker: faker.Faker) -> None: expected_result: typing.Final = faker.pylist(value_types=[str]) + config: typing.Final = YandexGPTConfigFactory.build() func_request: typing.Final = LLMFuncRequestFactory.build() - response: typing.Final = ( + response_content: typing.Final = ( "\n".join( YandexGPTResponse( result=YandexGPTResult( @@ -58,20 +60,24 @@ async def test_ok(self, faker: faker.Faker) -> None: ) + "\n" ) - client: typing.Final = mock_http_client( - any_llm_client.get_client(YandexGPTConfigFactory.build()), mock.AsyncMock(return_value=response) - ) + response: typing.Final = httpx.Response(200, content=response_content) - result: typing.Final = await consume_llm_partial_responses(client.stream_llm_partial_messages(**func_request)) + result: typing.Final = await consume_llm_partial_responses( + any_llm_client.get_client( + config, transport=httpx.MockTransport(lambda _: response) + ).stream_llm_partial_messages(**func_request) + ) assert result == expected_result async def test_fails_without_alternatives(self) -> None: - response: typing.Final = ( + response_content: typing.Final = ( YandexGPTResponse(result=YandexGPTResult.model_construct(alternatives=[])).model_dump_json() + "\n" ) - client: typing.Final = mock_http_client( - any_llm_client.get_client(YandexGPTConfigFactory.build()), mock.AsyncMock(return_value=response) + response: typing.Final = httpx.Response(200, content=response_content) + + client: typing.Final = any_llm_client.get_client( + YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response) ) with pytest.raises(pydantic.ValidationError): @@ -81,10 +87,9 @@ async def test_fails_without_alternatives(self) -> None: class TestYandexGPTLLMErrors: 
@pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("status_code", [400, 500]) - async def test_fails_with_unknown_error(self, faker: faker.Faker, stream: bool, status_code: int) -> None: - client: typing.Final = mock_http_client( - any_llm_client.get_client(YandexGPTConfigFactory.build()), - mock.AsyncMock(side_effect=HttpStatusError(status_code=status_code, content=faker.pystr().encode())), + async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None: + client: typing.Final = any_llm_client.get_client( + YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: httpx.Response(status_code)) ) coroutine: typing.Final = ( @@ -99,16 +104,16 @@ async def test_fails_with_unknown_error(self, faker: faker.Faker, stream: bool, @pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize( - "content", + "response_content", [ b"...folder_id=1111: number of input tokens must be no more than 8192, got 28498...", b"...folder_id=1111: text length is 349354, which is outside the range (0, 100000]...", ], ) - async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes) -> None: - client: typing.Final = mock_http_client( - any_llm_client.get_client(YandexGPTConfigFactory.build()), - mock.AsyncMock(side_effect=HttpStatusError(status_code=400, content=content)), + async def test_fails_with_out_of_tokens_error(self, stream: bool, response_content: bytes | None) -> None: + response: typing.Final = httpx.Response(400, content=response_content) + client: typing.Final = any_llm_client.get_client( + YandexGPTConfigFactory.build(), transport=httpx.MockTransport(lambda _: response) ) coroutine: typing.Final = ( diff --git a/tests/testing_app.py b/tests/testing_app.py deleted file mode 100644 index 66dc2ff..0000000 --- a/tests/testing_app.py +++ /dev/null @@ -1,28 +0,0 @@ -import typing - -import litestar -import litestar.background_tasks -from litestar.response import Stream - - -@litestar.get("/request-ok") -async def request_ok() -> dict[str, typing.Any]: - return {"ok": True} - - -@litestar.get("/request-fail", status_code=418) -async def request_fail() -> dict[str, typing.Any]: - return {"ok": False} - - -@litestar.get("/stream-ok") -async def stream_ok() -> Stream: - return Stream("ok\ntrue") - - -@litestar.get("/stream-fail") -async def stream_fail() -> Stream: - return Stream("ok\nfalse", status_code=418) - - -app = litestar.Litestar(route_handlers=[request_ok, request_fail, stream_ok, stream_fail])
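For reviewers, a minimal usage sketch (not part of the patch) of the new helpers introduced in `any_llm_client/http.py`; the target URL is a placeholder:

```python
import asyncio

from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.retry import RequestRetryConfig


async def main() -> None:
    client = get_http_client_from_kwargs({})  # httpx.AsyncClient with the default timeout applied
    retry = RequestRetryConfig()

    # Non-streaming: retried on httpx.HTTPError, raise_for_status() already called.
    response = await make_http_request(
        httpx_client=client,
        request_retry=retry,
        build_request=lambda: client.build_request("GET", "https://example.com"),
    )
    print(response.content)

    # Streaming: the context manager closes the response when the block exits.
    async with make_streaming_http_request(
        httpx_client=client,
        request_retry=retry,
        build_request=lambda: client.build_request("GET", "https://example.com"),
    ) as streaming_response:
        async for line in streaming_response.aiter_lines():
            print(line)

    await client.aclose()


asyncio.run(main())
```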