Move back to HTTPX #13

Merged 2 commits on Dec 5, 2024
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -11,7 +11,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: extractions/setup-just@v2
- uses: astral-sh/setup-uv@v4
- uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
10 changes: 0 additions & 10 deletions Justfile
@@ -10,18 +10,8 @@ lint:
uv run --group lint ruff format
uv run --group lint mypy .

_test-no-http *args:
uv run pytest --ignore tests/test_http.py {{ args }}

test *args:
#!/bin/bash
uv run litestar --app tests.testing_app:app run &
APP_PID=$!
uv run pytest {{ args }}
TEST_RESULT=$?
kill $APP_PID
wait $APP_PID 2>/dev/null
exit $TEST_RESULT

publish:
rm -rf dist
12 changes: 5 additions & 7 deletions README.md
@@ -162,25 +162,23 @@ async with any_llm_client.OpenAIClient(config, ...) as client:
#### Timeouts, proxy & other HTTP settings


Pass custom [niquests](https://niquests.readthedocs.io) kwargs to `any_llm_client.get_client()`:
Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`:

```python
import urllib3
import httpx

import any_llm_client


async with any_llm_client.get_client(
...,
proxies={"https://api.openai.com": "http://localhost:8030"},
timeout=urllib3.Timeout(total=10.0, connect=5.0),
mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
timeout=httpx.Timeout(None, connect=5.0),
) as client:
...
```

`timeout` and `proxies` parameters are special cased here: `niquests.AsyncSession` doesn't receive them by default.

Default timeout is `urllib3.Timeout(total=None, connect=5.0)`.
Default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect, unlimited on read, write or pool).
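For instance (an illustrative sketch, not part of the diff), keeping the default 5-second connect limit while capping read, write and pool time at 30 seconds would look roughly like this:

```python
import httpx

import any_llm_client


async with any_llm_client.get_client(
    ...,
    timeout=httpx.Timeout(30.0, connect=5.0),  # 5 s to connect, 30 s for read/write/pool
) as client:
    ...
```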

#### Retries

68 changes: 39 additions & 29 deletions any_llm_client/clients/openai.py
@@ -6,7 +6,8 @@
from http import HTTPStatus

import annotated_types
import niquests
import httpx
import httpx_sse
import pydantic
import typing_extensions

@@ -19,9 +20,8 @@
OutOfTokensOrSymbolsError,
UserMessage,
)
from any_llm_client.http import HttpClient, HttpStatusError
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.retry import RequestRetryConfig
from any_llm_client.sse import parse_sse_events


OPENAI_AUTH_TOKEN_ENV_NAME: typing.Final = "ANY_LLM_CLIENT_OPENAI_AUTH_TOKEN"
@@ -99,34 +99,31 @@ def _make_user_assistant_alternate_messages(
yield ChatCompletionsMessage(role=current_message_role, content="\n\n".join(current_message_content_chunks))


def _handle_status_error(error: HttpStatusError) -> typing.NoReturn:
if (
error.status_code == HTTPStatus.BAD_REQUEST and b"Please reduce the length of the messages" in error.content
): # vLLM
raise OutOfTokensOrSymbolsError(response_content=error.content)
raise LLMError(response_content=error.content)
def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn:
if status_code == HTTPStatus.BAD_REQUEST and b"Please reduce the length of the messages" in content: # vLLM
raise OutOfTokensOrSymbolsError(response_content=content)
raise LLMError(response_content=content)


@dataclasses.dataclass(slots=True, init=False)
class OpenAIClient(LLMClient):
config: OpenAIConfig
http_client: HttpClient
httpx_client: httpx.AsyncClient
request_retry: RequestRetryConfig

def __init__(
self,
config: OpenAIConfig,
*,
request_retry: RequestRetryConfig | None = None,
**niquests_kwargs: typing.Any, # noqa: ANN401
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> None:
self.config = config
self.http_client = HttpClient(
request_retry=request_retry or RequestRetryConfig(), niquests_kwargs=niquests_kwargs
)
self.request_retry = request_retry or RequestRetryConfig()
self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)

def _build_request(self, payload: dict[str, typing.Any]) -> niquests.Request:
return niquests.Request(
def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
return self.httpx_client.build_request(
method="POST",
url=str(self.config.url),
json=payload,
@@ -155,17 +152,24 @@ async def request_llm_message(
**extra or {},
).model_dump(mode="json")
try:
response: typing.Final = await self.http_client.request(self._build_request(payload))
except HttpStatusError as exception:
_handle_status_error(exception)
return ChatCompletionsNotStreamingResponse.model_validate_json(response).choices[0].message.content
response: typing.Final = await make_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
)
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
try:
return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content
finally:
await response.aclose()

async def _iter_partial_responses(self, response: typing.AsyncIterable[bytes]) -> typing.AsyncIterable[str]:
async def _iter_partial_responses(self, response: httpx.Response) -> typing.AsyncIterable[str]:
text_chunks: typing.Final = []
async for one_event in parse_sse_events(response):
if one_event.data == "[DONE]":
async for event in httpx_sse.EventSource(response).aiter_sse():
if event.data == "[DONE]":
break
validated_response = ChatCompletionsStreamingEvent.model_validate_json(one_event.data)
validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data)
if not (one_chunk := validated_response.choices[0].delta.content):
continue
text_chunks.append(one_chunk)
@@ -183,13 +187,19 @@ async def stream_llm_partial_messages(
**extra or {},
).model_dump(mode="json")
try:
async with self.http_client.stream(request=self._build_request(payload)) as response:
async with make_streaming_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
) as response:
yield self._iter_partial_responses(response)
except HttpStatusError as exception:
_handle_status_error(exception)
except httpx.HTTPStatusError as exception:
content: typing.Final = await exception.response.aread()
await exception.response.aclose()
_handle_status_error(status_code=exception.response.status_code, content=content)

async def __aenter__(self) -> typing_extensions.Self:
await self.http_client.__aenter__()
await self.httpx_client.__aenter__()
return self

async def __aexit__(
@@ -198,4 +208,4 @@ async def __aexit__(
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
await self.http_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
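Editor's note (not part of the diff): after this change, extra keyword arguments passed to the client constructor go to the underlying `httpx.AsyncClient` via `get_http_client_from_kwargs`, while retry behaviour stays in `RequestRetryConfig`. A rough usage sketch, with the config construction elided since its fields are not shown in this diff:

```python
import httpx

import any_llm_client
from any_llm_client.retry import RequestRetryConfig

config = ...  # an any_llm_client.OpenAIConfig instance; fields elided here

async with any_llm_client.OpenAIClient(
    config,
    request_retry=RequestRetryConfig(),        # optional; a default config is used if omitted
    timeout=httpx.Timeout(None, connect=5.0),  # extra kwargs are forwarded to httpx.AsyncClient
) as client:
    ...
```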
66 changes: 38 additions & 28 deletions any_llm_client/clients/yandexgpt.py
@@ -6,12 +6,12 @@
from http import HTTPStatus

import annotated_types
import niquests
import httpx
import pydantic
import typing_extensions

from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError, UserMessage
from any_llm_client.http import HttpClient, HttpStatusError
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.retry import RequestRetryConfig


@@ -61,34 +61,34 @@ class YandexGPTResponse(pydantic.BaseModel):
result: YandexGPTResult


def _handle_status_error(error: HttpStatusError) -> typing.NoReturn:
if error.status_code == HTTPStatus.BAD_REQUEST and (
b"number of input tokens must be no more than" in error.content
or (b"text length is" in error.content and b"which is outside the range" in error.content)
def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn:
if status_code == HTTPStatus.BAD_REQUEST and (
b"number of input tokens must be no more than" in content
or (b"text length is" in content and b"which is outside the range" in content)
):
raise OutOfTokensOrSymbolsError(response_content=error.content)
raise LLMError(response_content=error.content)
raise OutOfTokensOrSymbolsError(response_content=content)
raise LLMError(response_content=content)


@dataclasses.dataclass(slots=True, init=False)
class YandexGPTClient(LLMClient):
config: YandexGPTConfig
http_client: HttpClient
httpx_client: httpx.AsyncClient
request_retry: RequestRetryConfig

def __init__(
self,
config: YandexGPTConfig,
*,
request_retry: RequestRetryConfig | None = None,
**niquests_kwargs: typing.Any, # noqa: ANN401
**httpx_kwargs: typing.Any, # noqa: ANN401
) -> None:
self.config = config
self.http_client = HttpClient(
request_retry=request_retry or RequestRetryConfig(), niquests_kwargs=niquests_kwargs
)
self.request_retry = request_retry or RequestRetryConfig()
self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)

def _build_request(self, payload: dict[str, typing.Any]) -> niquests.Request:
return niquests.Request(
def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
return self.httpx_client.build_request(
method="POST",
url=str(self.config.url),
json=payload,
@@ -121,14 +121,18 @@ async def request_llm_message(
)

try:
response: typing.Final = await self.http_client.request(self._build_request(payload))
except HttpStatusError as exception:
_handle_status_error(exception)

return YandexGPTResponse.model_validate_json(response).result.alternatives[0].message.text

async def _iter_completion_messages(self, response: typing.AsyncIterable[bytes]) -> typing.AsyncIterable[str]:
async for one_line in response:
response: typing.Final = await make_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
)
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)

return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text

async def _iter_completion_messages(self, response: httpx.Response) -> typing.AsyncIterable[str]:
async for one_line in response.aiter_lines():
validated_response = YandexGPTResponse.model_validate_json(one_line)
yield validated_response.result.alternatives[0].message.text

@@ -141,13 +145,19 @@ async def stream_llm_partial_messages(
)

try:
async with self.http_client.stream(request=self._build_request(payload)) as response:
async with make_streaming_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
) as response:
yield self._iter_completion_messages(response)
except HttpStatusError as exception:
_handle_status_error(exception)
except httpx.HTTPStatusError as exception:
content: typing.Final = await exception.response.aread()
await exception.response.aclose()
_handle_status_error(status_code=exception.response.status_code, content=content)

async def __aenter__(self) -> typing_extensions.Self:
await self.http_client.__aenter__()
await self.httpx_client.__aenter__()
return self

async def __aexit__(
@@ -156,4 +166,4 @@ async def __aexit__(
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
await self.http_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
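Editor's note (not part of the diff): in both clients the streaming error path calls `aread()` before `_handle_status_error`, because an httpx response opened in streaming mode has no `.content` until its body has been read. A minimal standalone illustration of that httpx behaviour:

```python
import httpx


async def read_error_body(client: httpx.AsyncClient, url: str) -> bytes:
    # Open the response in streaming mode: headers are available, the body is not loaded yet.
    async with client.stream("GET", url) as response:
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exception:
            # Accessing .content here would raise httpx.ResponseNotRead,
            # so the body has to be read explicitly before it can be inspected.
            return await exception.response.aread()
        return await response.aread()
```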